import sys import os import pickle import tempfile import logging from typing import Optional, List, Dict from datetime import datetime, timedelta import pandas as pd import numpy as np import httpx from fastapi import FastAPI, HTTPException, Query, BackgroundTasks, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.responses import FileResponse, JSONResponse from fastapi.exceptions import RequestValidationError from starlette.exceptions import HTTPException as StarletteHTTPException from pydantic import BaseModel from umap import UMAP from utils.data_loader import ModelDataLoader from utils.embeddings import ModelEmbedder from utils.dimensionality_reduction import DimensionReducer from utils.network_analysis import ModelNetworkBuilder from utils.graph_embeddings import GraphEmbedder from services.model_tracker import get_tracker from services.arxiv_api import extract_arxiv_ids, fetch_arxiv_papers from core.config import settings from core.exceptions import DataNotLoadedError, EmbeddingsNotReadyError from models.schemas import ModelPoint from utils.family_tree import calculate_family_depths from utils.cache import cache, cached_response from utils.response_encoder import FastJSONResponse, MessagePackResponse, encode_models_msgpack import api.dependencies as deps from api.routes import models, stats, clusters # Create aliases for backward compatibility with existing routes # Note: These are set at module load time and may be None initially # Functions should access via deps.* to get current values data_loader = deps.data_loader backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if backend_dir not in sys.path: sys.path.insert(0, backend_dir) logger = logging.getLogger(__name__) app = FastAPI(title="HF Model Ecosystem API", version="2.0.0") app.add_middleware(GZipMiddleware, minimum_size=1000) CORS_HEADERS = { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "*", "Access-Control-Allow-Headers": "*", } @app.exception_handler(Exception) async def global_exception_handler(request: Request, exc: Exception): logger.exception("Unhandled exception", exc_info=exc) return JSONResponse( status_code=500, content={"detail": "Internal server error"}, headers=CORS_HEADERS, ) @app.exception_handler(StarletteHTTPException) async def http_exception_handler(request: Request, exc: StarletteHTTPException): return JSONResponse( status_code=exc.status_code, content={"detail": exc.detail}, headers=CORS_HEADERS, ) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): return JSONResponse( status_code=422, content={"detail": exc.errors()}, headers=CORS_HEADERS, ) if settings.ALLOW_ALL_ORIGINS: app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) else: app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:3000", settings.FRONTEND_URL], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Include routers app.include_router(models.router) app.include_router(stats.router) app.include_router(clusters.router) @app.on_event("startup") async def startup_event(): """ Fast startup using pre-computed data. Falls back to traditional loading if pre-computed data not available. """ import time startup_start = time.time() backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) # Try to load pre-computed data first (instant startup!) from utils.precomputed_loader import get_precomputed_loader precomputed_loader = get_precomputed_loader(version="v1") if precomputed_loader: logger.info("=" * 60) logger.info("LOADING PRE-COMPUTED DATA (Fast Startup Mode)") logger.info("=" * 60) try: # Check if chunked embeddings are available is_chunked = precomputed_loader.is_chunked() # Load data - don't load embeddings if chunked (load on-demand instead) load_embeddings_at_startup = not is_chunked # Only load if not chunked deps.df, deps.embeddings, metadata = precomputed_loader.load_all( load_embeddings=load_embeddings_at_startup ) # Initialize chunked loader if chunked data is available if is_chunked: chunked_loader = precomputed_loader.get_chunked_loader() if chunked_loader: deps.chunked_embedding_loader = chunked_loader logger.info("Chunked embedding loader initialized - embeddings will be loaded on-demand") else: logger.warning("Chunked data detected but chunked loader unavailable - falling back to full load") # Fallback: try to load all embeddings deps.df, deps.embeddings, metadata = precomputed_loader.load_all(load_embeddings=True) # Extract 3D coordinates from dataframe deps.reduced_embeddings = np.column_stack([ deps.df['x_3d'].values, deps.df['y_3d'].values, deps.df['z_3d'].values ]) # Initialize embedder (without loading/generating embeddings) deps.embedder = ModelEmbedder() # Initialize reducer (already fitted) deps.reducer = DimensionReducer(method="umap", n_components=3) # No graph embeddings in fast mode (optional feature) deps.graph_embedder = None deps.graph_embeddings_dict = None deps.combined_embeddings = None deps.reduced_embeddings_graph = None startup_time = time.time() - startup_start logger.info("=" * 60) logger.info(f"STARTUP COMPLETE in {startup_time:.2f} seconds!") logger.info(f"Loaded {len(deps.df):,} models with pre-computed coordinates") if is_chunked: logger.info("Using chunked embeddings - fast startup mode enabled") logger.info(f"Unique libraries: {metadata.get('unique_libraries')}") logger.info(f"Unique pipelines: {metadata.get('unique_pipelines')}") logger.info("=" * 60) # Update module-level aliases df = deps.df embedder = deps.embedder reducer = deps.reducer embeddings = deps.embeddings reduced_embeddings = deps.reduced_embeddings return except Exception as e: logger.warning(f"Failed to load pre-computed data: {e}") logger.info("Falling back to traditional loading...") else: logger.info("=" * 60) logger.info("Pre-computed data not found.") logger.info("To enable fast startup, run:") logger.info(" cd backend && python scripts/precompute_data.py --sample-size 150000") logger.info("=" * 60) logger.info("Falling back to traditional loading (may take 1-8 hours)...") # Traditional loading (slow path) cache_dir = os.path.join(root_dir, "cache") os.makedirs(cache_dir, exist_ok=True) embeddings_cache = os.path.join(cache_dir, "embeddings.pkl") graph_embeddings_cache = os.path.join(cache_dir, "graph_embeddings.pkl") combined_embeddings_cache = os.path.join(cache_dir, "combined_embeddings.pkl") reduced_cache_umap = os.path.join(cache_dir, "reduced_umap_3d.pkl") reduced_cache_umap_graph = os.path.join(cache_dir, "reduced_umap_3d_graph.pkl") reducer_cache_umap = os.path.join(cache_dir, "reducer_umap_3d.pkl") reducer_cache_umap_graph = os.path.join(cache_dir, "reducer_umap_3d_graph.pkl") # Load dataset with sample (for reasonable startup time) sample_size = settings.SAMPLE_SIZE or settings.get_sample_size() or 5000 logger.info(f"Loading dataset (sample_size={sample_size}, prioritizing base models)...") deps.df = deps.data_loader.load_data(sample_size=sample_size, prioritize_base_models=True) deps.df = deps.data_loader.preprocess_for_embedding(deps.df) if 'model_id' in deps.df.columns: deps.df.set_index('model_id', drop=False, inplace=True) for col in ['downloads', 'likes']: if col in deps.df.columns: deps.df[col] = pd.to_numeric(deps.df[col], errors='coerce').fillna(0).astype(int) deps.embedder = ModelEmbedder() # Load or generate text embeddings if os.path.exists(embeddings_cache): try: deps.embeddings = deps.embedder.load_embeddings(embeddings_cache) except (IOError, pickle.UnpicklingError, EOFError) as e: logger.warning(f"Failed to load cached embeddings: {e}") deps.embeddings = None if deps.embeddings is None: texts = deps.df['combined_text'].tolist() deps.embeddings = deps.embedder.generate_embeddings(texts, batch_size=128) deps.embedder.save_embeddings(deps.embeddings, embeddings_cache) # Skip graph embeddings in fallback mode (too slow) deps.graph_embedder = None deps.graph_embeddings_dict = None deps.combined_embeddings = None # Initialize reducer for text embeddings deps.reducer = DimensionReducer(method="umap", n_components=3) # Pre-compute clusters for faster requests logger.info("Pre-computing clusters...") if os.path.exists(reduced_cache_umap) and os.path.exists(reducer_cache_umap): try: with open(reduced_cache_umap, 'rb') as f: deps.reduced_embeddings = pickle.load(f) deps.reducer.load_reducer(reducer_cache_umap) except (IOError, pickle.UnpicklingError, EOFError) as e: logger.warning(f"Failed to load cached reduced embeddings: {e}") deps.reduced_embeddings = None if deps.reduced_embeddings is None: deps.reducer.reducer = UMAP( n_components=3, n_neighbors=30, min_dist=0.3, metric='cosine', random_state=42, n_jobs=-1, low_memory=True, spread=1.5 ) deps.reduced_embeddings = deps.reducer.fit_transform(deps.embeddings) with open(reduced_cache_umap, 'wb') as f: pickle.dump(deps.reduced_embeddings, f) deps.reducer.save_reducer(reducer_cache_umap) # No graph embeddings in fallback mode deps.reduced_embeddings_graph = None # Pre-compute clusters now instead of on first request if deps.reduced_embeddings is not None and len(deps.reduced_embeddings) > 0: models.cluster_labels = compute_clusters( deps.reduced_embeddings, n_clusters=min(50, len(deps.reduced_embeddings) // 100) ) logger.info(f"Pre-computed {len(set(models.cluster_labels))} clusters") startup_time = time.time() - startup_start logger.info(f"Startup complete in {startup_time:.2f} seconds") # Update module-level aliases df = deps.df embedder = deps.embedder graph_embedder = deps.graph_embedder reducer = deps.reducer embeddings = deps.embeddings graph_embeddings_dict = deps.graph_embeddings_dict combined_embeddings = deps.combined_embeddings reduced_embeddings = deps.reduced_embeddings reduced_embeddings_graph = deps.reduced_embeddings_graph from utils.family_tree import calculate_family_depths def compute_clusters(reduced_embeddings: np.ndarray, n_clusters: int = 50) -> np.ndarray: from sklearn.cluster import KMeans n_samples = len(reduced_embeddings) if n_samples < n_clusters: n_clusters = max(1, n_samples // 10) kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) return kmeans.fit_predict(reduced_embeddings) @app.get("/") async def root(): # Check if frontend build exists and serve it _backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) _frontend_build_path = os.path.join(os.path.dirname(_backend_dir), "frontend", "build") index_path = os.path.join(_frontend_build_path, "index.html") if os.path.exists(index_path): from starlette.responses import FileResponse as StarletteFileResponse return StarletteFileResponse(index_path) # Fallback to API status when no frontend build return {"message": "HF Model Ecosystem API", "status": "running"} @app.get("/api/models") async def get_models( min_downloads: int = Query(0), min_likes: int = Query(0), search_query: Optional[str] = Query(None), color_by: str = Query("library_name"), size_by: str = Query("downloads"), max_points: Optional[int] = Query(10000), # REDUCED from None (was 50k default in frontend) projection_method: str = Query("umap"), base_models_only: bool = Query(False), max_hierarchy_depth: Optional[int] = Query(None, ge=0, description="Filter to models at or below this hierarchy depth."), use_graph_embeddings: bool = Query(False, description="Use graph-aware embeddings that respect family tree structure"), format: str = Query("json", regex="^(json|msgpack)$", description="Response format: json or msgpack") ): if deps.df is None: raise DataNotLoadedError() df = deps.df # Filter data filtered_df = data_loader.filter_data( df=df, min_downloads=min_downloads, min_likes=min_likes, search_query=search_query, libraries=None, # Can be added as query params pipeline_tags=None ) if base_models_only: if 'parent_model' in filtered_df.columns: filtered_df = filtered_df[ filtered_df['parent_model'].isna() | (filtered_df['parent_model'].astype(str).str.strip() == '') | (filtered_df['parent_model'].astype(str) == 'nan') ] if max_hierarchy_depth is not None: family_depths = calculate_family_depths(df) filtered_df = filtered_df[ filtered_df['model_id'].astype(str).map(lambda x: family_depths.get(x, 0) <= max_hierarchy_depth) ] filtered_count = len(filtered_df) if len(filtered_df) == 0: return { "models": [], "filtered_count": 0, "returned_count": 0 } # Handle max_points: None means no limit, very large number also means no limit effective_max_points = None if max_points is None or max_points >= 1000000 else max_points if effective_max_points is not None and len(filtered_df) > effective_max_points: if 'library_name' in filtered_df.columns and filtered_df['library_name'].notna().any(): # Sample proportionally by library, preserving all columns sampled_dfs = [] for lib_name, group in filtered_df.groupby('library_name', group_keys=False): n_samples = max(1, int(effective_max_points * len(group) / len(filtered_df))) sampled_dfs.append(group.sample(min(len(group), n_samples), random_state=42)) filtered_df = pd.concat(sampled_dfs, ignore_index=True) if len(filtered_df) > effective_max_points: filtered_df = filtered_df.sample(n=effective_max_points, random_state=42).reset_index(drop=True) else: filtered_df = filtered_df.reset_index(drop=True) else: filtered_df = filtered_df.sample(n=effective_max_points, random_state=42).reset_index(drop=True) # Determine which embeddings to use if use_graph_embeddings and combined_embeddings is not None: current_embeddings = combined_embeddings current_reduced = reduced_embeddings_graph embedding_type = "graph-aware" else: if embeddings is None: raise EmbeddingsNotReadyError() current_embeddings = embeddings current_reduced = reduced_embeddings embedding_type = "text-only" # Handle reduced embeddings loading/generation if current_reduced is None or (reducer and reducer.method != projection_method.lower()): backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) cache_dir = os.path.join(root_dir, "cache") cache_suffix = "_graph" if use_graph_embeddings and combined_embeddings is not None else "" reduced_cache = os.path.join(cache_dir, f"reduced_{projection_method.lower()}_3d{cache_suffix}.pkl") reducer_cache = os.path.join(cache_dir, f"reducer_{projection_method.lower()}_3d{cache_suffix}.pkl") if os.path.exists(reduced_cache) and os.path.exists(reducer_cache): try: with open(reduced_cache, 'rb') as f: current_reduced = pickle.load(f) if reducer is None or reducer.method != projection_method.lower(): reducer = DimensionReducer(method=projection_method.lower(), n_components=3) reducer.load_reducer(reducer_cache) except (IOError, pickle.UnpicklingError, EOFError) as e: logger.warning(f"Failed to load cached reduced embeddings: {e}") current_reduced = None if current_reduced is None: if reducer is None or reducer.method != projection_method.lower(): reducer = DimensionReducer(method=projection_method.lower(), n_components=3) if projection_method.lower() == "umap": reducer.reducer = UMAP( n_components=3, n_neighbors=30, min_dist=0.3, metric='cosine', random_state=42, n_jobs=-1, low_memory=True, spread=1.5 ) current_reduced = reducer.fit_transform(current_embeddings) with open(reduced_cache, 'wb') as f: pickle.dump(current_reduced, f) reducer.save_reducer(reducer_cache) # Update global variable if use_graph_embeddings and deps.combined_embeddings is not None: deps.reduced_embeddings_graph = current_reduced else: deps.reduced_embeddings = current_reduced # Get indices for filtered data # Use model_id column to map between filtered_df and original df # This is safer than using index positions which can change after filtering filtered_model_ids = filtered_df['model_id'].astype(str).values # Map model_ids to positions in original df if df.index.name == 'model_id' or 'model_id' in df.index.names: # When df is indexed by model_id, use get_loc directly filtered_indices = [] for model_id in filtered_model_ids: try: pos = df.index.get_loc(model_id) # Handle both single position and array of positions if isinstance(pos, (int, np.integer)): filtered_indices.append(int(pos)) elif isinstance(pos, (slice, np.ndarray)): # If multiple matches, take first if isinstance(pos, slice): filtered_indices.append(int(pos.start)) else: filtered_indices.append(int(pos[0])) except (KeyError, TypeError): continue filtered_indices = np.array(filtered_indices, dtype=np.int32) else: # When df is not indexed by model_id, find positions by matching model_id column df_model_ids = df['model_id'].astype(str).values model_id_to_pos = {mid: pos for pos, mid in enumerate(df_model_ids)} filtered_indices = np.array([ model_id_to_pos[mid] for mid in filtered_model_ids if mid in model_id_to_pos ], dtype=np.int32) if len(filtered_indices) == 0: return { "models": [], "embedding_type": embedding_type, "filtered_count": filtered_count, "returned_count": 0 } filtered_reduced = current_reduced[filtered_indices] family_depths = calculate_family_depths(df) # Use appropriate embeddings for clustering clustering_embeddings = current_reduced # Compute clusters if not already computed or if size changed if models.cluster_labels is None or len(models.cluster_labels) != len(clustering_embeddings): models.cluster_labels = compute_clusters(clustering_embeddings, n_clusters=min(50, len(clustering_embeddings) // 100)) # Handle case where cluster_labels might not match filtered data yet if models.cluster_labels is not None and len(models.cluster_labels) > 0: if len(filtered_indices) <= len(models.cluster_labels): filtered_clusters = models.cluster_labels[filtered_indices] else: # Fallback: use first cluster for all if indices don't match filtered_clusters = np.zeros(len(filtered_indices), dtype=int) else: filtered_clusters = np.zeros(len(filtered_indices), dtype=int) model_ids = filtered_df['model_id'].astype(str).values library_names = filtered_df.get('library_name', pd.Series([None] * len(filtered_df))).values pipeline_tags = filtered_df.get('pipeline_tag', pd.Series([None] * len(filtered_df))).values downloads_arr = filtered_df.get('downloads', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values likes_arr = filtered_df.get('likes', pd.Series([0] * len(filtered_df))).fillna(0).astype(int).values trending_scores = filtered_df.get('trendingScore', pd.Series([None] * len(filtered_df))).values tags_arr = filtered_df.get('tags', pd.Series([None] * len(filtered_df))).values parent_models = filtered_df.get('parent_model', pd.Series([None] * len(filtered_df))).values licenses_arr = filtered_df.get('licenses', pd.Series([None] * len(filtered_df))).values created_at_arr = filtered_df.get('createdAt', pd.Series([None] * len(filtered_df))).values x_coords = filtered_reduced[:, 0].astype(float) y_coords = filtered_reduced[:, 1].astype(float) z_coords = filtered_reduced[:, 2].astype(float) if filtered_reduced.shape[1] > 2 else np.zeros(len(filtered_reduced), dtype=float) models = [ ModelPoint( model_id=model_ids[idx], x=float(x_coords[idx]), y=float(y_coords[idx]), z=float(z_coords[idx]), library_name=library_names[idx] if pd.notna(library_names[idx]) else None, pipeline_tag=pipeline_tags[idx] if pd.notna(pipeline_tags[idx]) else None, downloads=int(downloads_arr[idx]), likes=int(likes_arr[idx]), trending_score=float(trending_scores[idx]) if idx < len(trending_scores) and pd.notna(trending_scores[idx]) else None, tags=tags_arr[idx] if idx < len(tags_arr) and pd.notna(tags_arr[idx]) else None, parent_model=parent_models[idx] if idx < len(parent_models) and pd.notna(parent_models[idx]) else None, licenses=licenses_arr[idx] if idx < len(licenses_arr) and pd.notna(licenses_arr[idx]) else None, family_depth=family_depths.get(model_ids[idx], None), cluster_id=int(filtered_clusters[idx]) if idx < len(filtered_clusters) else None, created_at=str(created_at_arr[idx]) if idx < len(created_at_arr) and pd.notna(created_at_arr[idx]) else None ) for idx in range(len(filtered_df)) ] # Return models with metadata about embedding type response_data = { "models": models, "embedding_type": embedding_type, "filtered_count": filtered_count, "returned_count": len(models) } # Return in requested format with caching headers if format == "msgpack": try: binary_data = encode_models_msgpack([m.dict() for m in models]) return Response( content=binary_data, media_type="application/msgpack", headers={ "Cache-Control": "public, max-age=300", "X-Content-Type-Options": "nosniff", "Access-Control-Expose-Headers": "Cache-Control" } ) except Exception as e: logger.warning(f"MessagePack encoding failed, falling back to JSON: {e}") # Return JSON with caching headers return FastJSONResponse( content=response_data, headers={ "Cache-Control": "public, max-age=300", "X-Content-Type-Options": "nosniff", "Access-Control-Expose-Headers": "Cache-Control" } ) @app.get("/api/stats") async def get_stats(): """Get dataset statistics.""" if df is None: raise DataNotLoadedError() total_models = len(df.index) if hasattr(df, 'index') else len(df) # Get unique licenses with counts licenses = {} if 'license' in df.columns: license_counts = df['license'].value_counts().to_dict() licenses = {str(k): int(v) for k, v in license_counts.items() if pd.notna(k) and str(k) != 'nan'} return { "total_models": total_models, "unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0, "unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0, "unique_task_types": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0, # Alias for clarity "unique_licenses": len(licenses), "licenses": licenses, # License name -> count mapping "avg_downloads": float(df['downloads'].mean()) if 'downloads' in df.columns else 0, "avg_likes": float(df['likes'].mean()) if 'likes' in df.columns else 0 } @app.get("/api/model/{model_id}") async def get_model_details(model_id: str): """Get detailed information about a specific model.""" if df is None: raise DataNotLoadedError() model = df[df.get('model_id', '') == model_id] if len(model) == 0: raise HTTPException(status_code=404, detail="Model not found") model = model.iloc[0] tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else '' arxiv_ids = extract_arxiv_ids(tags_str) papers = [] if arxiv_ids: papers = await fetch_arxiv_papers(arxiv_ids[:5]) # Limit to 5 papers return { "model_id": model.get('model_id'), "library_name": model.get('library_name'), "pipeline_tag": model.get('pipeline_tag'), "downloads": int(model.get('downloads', 0)), "likes": int(model.get('likes', 0)), "trending_score": float(model.get('trendingScore', 0)) if pd.notna(model.get('trendingScore')) else None, "tags": model.get('tags') if pd.notna(model.get('tags')) else None, "licenses": model.get('licenses') if pd.notna(model.get('licenses')) else None, "parent_model": model.get('parent_model') if pd.notna(model.get('parent_model')) else None, "arxiv_papers": papers, "arxiv_ids": arxiv_ids } # Clusters endpoint is handled by routes/clusters.py router @app.get("/api/family/stats") async def get_family_stats(): """ Get aggregate statistics about family trees for paper visualizations. Returns family size distribution, depth statistics, model card length by depth, etc. """ if df is None: raise DataNotLoadedError() family_sizes = {} root_models = set() for idx, row in df.iterrows(): model_id = str(row.get('model_id', '')) parent_id = row.get('parent_model') if pd.isna(parent_id) or str(parent_id) == 'nan' or str(parent_id) == '': root_models.add(model_id) if model_id not in family_sizes: family_sizes[model_id] = 0 else: parent_id_str = str(parent_id) root = parent_id_str visited = set() while root in df.index and pd.notna(df.loc[root].get('parent_model')): parent = df.loc[root].get('parent_model') if pd.isna(parent) or str(parent) == 'nan' or str(parent) == '': break if str(parent) in visited: break visited.add(root) root = str(parent) if root not in family_sizes: family_sizes[root] = 0 family_sizes[root] += 1 size_distribution = {} for root, size in family_sizes.items(): size_distribution[size] = size_distribution.get(size, 0) + 1 depths = calculate_family_depths(df) depth_counts = {} for depth in depths.values(): depth_counts[depth] = depth_counts.get(depth, 0) + 1 model_card_lengths_by_depth = {} if 'modelCard' in df.columns: for idx, row in df.iterrows(): model_id = str(row.get('model_id', '')) depth = depths.get(model_id, 0) model_card = row.get('modelCard', '') if pd.notna(model_card): card_length = len(str(model_card)) if depth not in model_card_lengths_by_depth: model_card_lengths_by_depth[depth] = [] model_card_lengths_by_depth[depth].append(card_length) model_card_stats = {} for depth, lengths in model_card_lengths_by_depth.items(): if lengths: model_card_stats[depth] = { "mean": float(np.mean(lengths)), "median": float(np.median(lengths)), "q1": float(np.percentile(lengths, 25)), "q3": float(np.percentile(lengths, 75)), "min": float(np.min(lengths)), "max": float(np.max(lengths)), "count": len(lengths) } return { "total_families": len(root_models), "family_size_distribution": size_distribution, "depth_distribution": depth_counts, "max_family_size": max(family_sizes.values()) if family_sizes else 0, "max_depth": max(depths.values()) if depths else 0, "avg_family_size": sum(family_sizes.values()) / len(family_sizes) if family_sizes else 0, "model_card_length_by_depth": model_card_stats } @app.get("/api/family/top") async def get_top_families( limit: int = Query(50, ge=1, le=200, description="Maximum number of families to return"), min_size: int = Query(2, ge=1, description="Minimum family size to include") ): """ Get top families by total lineage count (sum of all descendants). Calculates the actual family tree size by traversing parent-child relationships. """ if deps.df is None: raise DataNotLoadedError() df = deps.df # Build parent -> children mapping children_map = {} root_models = set() for idx, row in df.iterrows(): model_id = str(row.get('model_id', '')) parent_id = row.get('parent_model') if pd.isna(parent_id) or str(parent_id) == 'nan' or str(parent_id) == '': root_models.add(model_id) else: parent_str = str(parent_id) if parent_str not in children_map: children_map[parent_str] = [] children_map[parent_str].append(model_id) # For each root, count all descendants def count_descendants(model_id: str, visited: set) -> int: if model_id in visited: return 0 visited.add(model_id) count = 1 # Count self for child in children_map.get(model_id, []): count += count_descendants(child, visited) return count # Calculate family sizes family_data = [] for root in root_models: visited = set() total_count = count_descendants(root, visited) if total_count >= min_size: # Get organization from model_id org = root.split('/')[0] if '/' in root else root family_data.append({ "root_model": root, "organization": org, "total_models": total_count, "depth_count": len(visited) # Same as total for tree traversal }) # Sort by total count descending family_data.sort(key=lambda x: x['total_models'], reverse=True) # Also aggregate by organization (sum all families under same org) org_totals = {} for fam in family_data: org = fam['organization'] if org not in org_totals: org_totals[org] = { "organization": org, "total_models": 0, "family_count": 0, "root_models": [] } org_totals[org]['total_models'] += fam['total_models'] org_totals[org]['family_count'] += 1 if len(org_totals[org]['root_models']) < 5: # Keep top 5 root models org_totals[org]['root_models'].append(fam['root_model']) # Sort organizations by total models top_orgs = sorted(org_totals.values(), key=lambda x: x['total_models'], reverse=True)[:limit] return { "families": family_data[:limit], "organizations": top_orgs, "total_families": len(family_data), "total_root_models": len(root_models) } @app.get("/api/family/path/{model_id}") async def get_family_path( model_id: str, target_id: Optional[str] = Query(None, description="Target model ID. If None, returns path to root.") ): """ Get path from model to root or to target model. Returns list of model IDs representing the path. """ if df is None: raise DataNotLoadedError() model_id_str = str(model_id) if df.index.name == 'model_id': if model_id_str not in df.index: raise HTTPException(status_code=404, detail="Model not found") else: model_rows = df[df.get('model_id', '') == model_id_str] if len(model_rows) == 0: raise HTTPException(status_code=404, detail="Model not found") path = [model_id_str] visited = set([model_id_str]) current = model_id_str if target_id: target_str = str(target_id) if df.index.name == 'model_id': if target_str not in df.index: raise HTTPException(status_code=404, detail="Target model not found") while current != target_str and current not in visited: try: if df.index.name == 'model_id': row = df.loc[current] else: rows = df[df.get('model_id', '') == current] if len(rows) == 0: break row = rows.iloc[0] parent_id = row.get('parent_model') if parent_id and pd.notna(parent_id): parent_str = str(parent_id) if parent_str == target_str: path.append(parent_str) break if parent_str not in visited: path.append(parent_str) visited.add(parent_str) current = parent_str else: break else: break except (KeyError, IndexError): break else: while True: try: if df.index.name == 'model_id': row = df.loc[current] else: rows = df[df.get('model_id', '') == current] if len(rows) == 0: break row = rows.iloc[0] parent_id = row.get('parent_model') if parent_id and pd.notna(parent_id): parent_str = str(parent_id) if parent_str not in visited: path.append(parent_str) visited.add(parent_str) current = parent_str else: break else: break except (KeyError, IndexError): break return { "path": path, "source": model_id_str, "target": target_id if target_id else "root", "path_length": len(path) - 1 } @app.get("/api/family/{model_id}") async def get_family_tree( model_id: str, max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."), max_depth_filter: Optional[int] = Query(None, ge=0, description="Filter results to models at or below this hierarchy depth.") ): """ Get family tree for a model (ancestors and descendants). Returns the model, its parent chain, and all children. If max_depth is None, traverses the entire family tree without depth limits. """ if df is None: raise DataNotLoadedError() if reduced_embeddings is None: raise HTTPException(status_code=503, detail="Embeddings not ready") model_id_str = str(model_id) if df.index.name == 'model_id': if model_id_str not in df.index: raise HTTPException(status_code=404, detail="Model not found") model_lookup = df.loc else: model_rows = df[df.get('model_id', '') == model_id_str] if len(model_rows) == 0: raise HTTPException(status_code=404, detail="Model not found") model_lookup = lambda x: df[df.get('model_id', '') == x] from utils.network_analysis import _get_all_parents, _parse_parent_list children_index: Dict[str, List[str]] = {} parent_columns = ['parent_model', 'finetune_parent', 'quantized_parent', 'adapter_parent', 'merge_parent'] for idx, row in df.iterrows(): model_id_from_row = str(row.get('model_id', idx)) all_parents = _get_all_parents(row) for rel_type, parent_list in all_parents.items(): for parent_str in parent_list: if parent_str not in children_index: children_index[parent_str] = [] children_index[parent_str].append(model_id_from_row) visited = set() def get_ancestors(current_id: str, depth: Optional[int]): if current_id in visited: return if depth is not None and depth <= 0: return visited.add(current_id) try: if df.index.name == 'model_id': row = df.loc[current_id] else: rows = model_lookup(current_id) if len(rows) == 0: return row = rows.iloc[0] all_parents = _get_all_parents(row) for rel_type, parent_list in all_parents.items(): for parent_str in parent_list: if parent_str != 'nan' and parent_str != '': next_depth = depth - 1 if depth is not None else None get_ancestors(parent_str, next_depth) except (KeyError, IndexError): return def get_descendants(current_id: str, depth: Optional[int]): if current_id in visited: return if depth is not None and depth <= 0: return visited.add(current_id) children = children_index.get(current_id, []) for child_id in children: if child_id not in visited: next_depth = depth - 1 if depth is not None else None get_descendants(child_id, next_depth) get_ancestors(model_id_str, max_depth) visited = set() get_descendants(model_id_str, max_depth) visited.add(model_id_str) if df.index.name == 'model_id': try: family_df = df.loc[list(visited)] except KeyError: missing = [v for v in visited if v not in df.index] if missing: logger.warning(f"Some family members not found in index: {missing}") family_df = df.loc[[v for v in visited if v in df.index]] else: family_df = df[df.get('model_id', '').isin(visited)] if len(family_df) == 0: raise HTTPException(status_code=404, detail="Family tree data not available") family_indices = family_df.index.values if len(family_indices) > len(reduced_embeddings): raise HTTPException(status_code=503, detail="Embedding indices mismatch") family_reduced = reduced_embeddings[family_indices] family_map = {} for idx, (i, row) in enumerate(family_df.iterrows()): model_id_val = str(row.get('model_id', i)) parent_id = row.get('parent_model') parent_id_str = str(parent_id) if parent_id and pd.notna(parent_id) else None depths = calculate_family_depths(df) model_depth = depths.get(model_id_val, 0) if max_depth_filter is not None and model_depth > max_depth_filter: continue family_map[model_id_val] = { "model_id": model_id_val, "x": float(family_reduced[idx, 0]), "y": float(family_reduced[idx, 1]), "z": float(family_reduced[idx, 2]) if family_reduced.shape[1] > 2 else 0.0, "library_name": str(row.get('library_name')) if pd.notna(row.get('library_name')) else None, "pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None, "downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0, "likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0, "parent_model": parent_id_str, "licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None, "family_depth": model_depth, "children": [] } root_models = [] for model_id_val, model_data in family_map.items(): parent_id = model_data["parent_model"] if parent_id and parent_id in family_map: family_map[parent_id]["children"].append(model_id_val) else: root_models.append(model_id_val) return { "root_model": model_id_str, "family": list(family_map.values()), "family_map": family_map, "root_models": root_models } @app.get("/api/search") async def search_models( q: str = Query(..., min_length=1, alias="query"), query: str = Query(None, min_length=1), limit: int = Query(20, ge=1, le=100), graph_aware: bool = Query(False), include_neighbors: bool = Query(True) ): """ Search for models by name (for autocomplete and family tree lookup). Enhanced with graph-aware search option that includes network relationships. """ if df is None: raise DataNotLoadedError() # Support both 'q' and 'query' parameters search_query = query or q if graph_aware: try: network_builder = ModelNetworkBuilder(df) top_models = network_builder.get_top_models_by_field(n=1000) model_ids = [mid for mid, _ in top_models] graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined') results = network_builder.search_graph_aware( query=search_query, graph=graph, max_results=limit, include_neighbors=include_neighbors ) return {"results": results, "search_type": "graph_aware", "query": search_query} except (ValueError, KeyError, AttributeError) as e: logger.warning(f"Graph-aware search failed, falling back to basic search: {e}") query_lower = search_query.lower() # Enhanced search: search model_id, org, tags, library, pipeline model_id_col = df.get('model_id', '').astype(str).str.lower() library_col = df.get('library_name', '').astype(str).str.lower() pipeline_col = df.get('pipeline_tag', '').astype(str).str.lower() tags_col = df.get('tags', '').astype(str).str.lower() license_col = df.get('license', '').astype(str).str.lower() # Extract org from model_id org_col = model_id_col.str.split('/').str[0] # Multi-field search mask = ( model_id_col.str.contains(query_lower, na=False) | org_col.str.contains(query_lower, na=False) | library_col.str.contains(query_lower, na=False) | pipeline_col.str.contains(query_lower, na=False) | tags_col.str.contains(query_lower, na=False) | license_col.str.contains(query_lower, na=False) ) matches = df[mask].head(limit) results = [] for _, row in matches.iterrows(): model_id = str(row.get('model_id', '')) org = model_id.split('/')[0] if '/' in model_id else '' # Get coordinates if available x = float(row.get('x', 0.0)) if 'x' in row else None y = float(row.get('y', 0.0)) if 'y' in row else None z = float(row.get('z', 0.0)) if 'z' in row else None results.append({ "model_id": model_id, "x": x, "y": y, "z": z, "org": org, "library": row.get('library_name'), "pipeline": row.get('pipeline_tag'), "license": row.get('license') if pd.notna(row.get('license')) else None, "downloads": int(row.get('downloads', 0)), "likes": int(row.get('likes', 0)), "parent_model": row.get('parent_model') if pd.notna(row.get('parent_model')) else None, "match_type": "direct" }) return {"results": results, "search_type": "basic", "query": search_query} @app.get("/api/search/fuzzy") async def fuzzy_search_models( q: str = Query(..., min_length=2, description="Search query"), limit: int = Query(50, ge=1, le=200, description="Maximum number of results"), threshold: int = Query(60, ge=0, le=100, description="Minimum fuzzy match score (0-100)"), ): """ Fuzzy search for models using rapidfuzz. Handles typos and partial matches across model names, libraries, and pipelines. Returns results sorted by relevance score. """ if deps.df is None: raise DataNotLoadedError() df = deps.df try: from rapidfuzz import fuzz, process from rapidfuzz.utils import default_process query_lower = q.lower().strip() # Prepare choices - combine model_id, library, and pipeline for searching # Create a searchable string for each model model_ids = df['model_id'].astype(str).tolist() libraries = df.get('library_name', pd.Series([''] * len(df))).fillna('').astype(str).tolist() pipelines = df.get('pipeline_tag', pd.Series([''] * len(df))).fillna('').astype(str).tolist() # Create search strings - just model_id for better fuzzy matching # Library and pipeline are used for secondary filtering search_strings = [m.lower() for m in model_ids] # Use rapidfuzz to find best matches # WRatio is best for general fuzzy matching with typo tolerance # It handles transpositions, insertions, deletions well # extract returns list of (match, score, index) matches = process.extract( query_lower, search_strings, scorer=fuzz.WRatio, limit=limit * 3, # Get extra to filter by threshold and dedupe score_cutoff=threshold, processor=default_process ) # Also try partial matching for substring searches if len(matches) < limit: partial_matches = process.extract( query_lower, search_strings, scorer=fuzz.partial_ratio, limit=limit * 2, score_cutoff=threshold + 10, # Higher threshold for partial processor=default_process ) # Add unique partial matches seen_indices = {m[2] for m in matches} for m in partial_matches: if m[2] not in seen_indices: matches.append(m) seen_indices.add(m[2]) results = [] seen_ids = set() for match_str, score, idx in matches: if len(results) >= limit: break model_id = model_ids[idx] if model_id in seen_ids: continue seen_ids.add(model_id) row = df.iloc[idx] # Get coordinates x = float(row.get('x', 0.0)) if 'x' in row else None y = float(row.get('y', 0.0)) if 'y' in row else None z = float(row.get('z', 0.0)) if 'z' in row else None results.append({ "model_id": model_id, "x": x, "y": y, "z": z, "score": round(score, 1), "library": row.get('library_name') if pd.notna(row.get('library_name')) else None, "pipeline": row.get('pipeline_tag') if pd.notna(row.get('pipeline_tag')) else None, "downloads": int(row.get('downloads', 0)), "likes": int(row.get('likes', 0)), "family_depth": int(row.get('family_depth', 0)) if pd.notna(row.get('family_depth')) else None, }) # Sort by score descending, then by downloads for tie-breaking results.sort(key=lambda x: (-x['score'], -x['downloads'])) return { "results": results, "query": q, "total_matches": len(matches), "threshold": threshold } except ImportError: raise HTTPException(status_code=500, detail="rapidfuzz not installed") except Exception as e: logger.exception(f"Fuzzy search error: {e}") raise HTTPException(status_code=500, detail=f"Search error: {str(e)}") @app.get("/api/similar/{model_id}") async def get_similar_models(model_id: str, k: int = Query(10, ge=1, le=50)): """ Get k-nearest neighbors of a model based on embedding similarity. Returns similar models with distance scores. """ if deps.df is None or deps.embeddings is None: raise HTTPException(status_code=503, detail="Data not loaded") df = deps.df embeddings = deps.embeddings if 'model_id' in df.index.names or df.index.name == 'model_id': try: model_row = df.loc[[model_id]] model_idx = model_row.index[0] except KeyError: raise HTTPException(status_code=404, detail="Model not found") else: model_row = df[df.get('model_id', '') == model_id] if len(model_row) == 0: raise HTTPException(status_code=404, detail="Model not found") model_idx = model_row.index[0] model_embedding = embeddings[model_idx] from sklearn.metrics.pairwise import cosine_similarity model_embedding_2d = model_embedding.reshape(1, -1) similarities = cosine_similarity(model_embedding_2d, embeddings)[0] top_k_indices = np.argpartition(similarities, -k-1)[-k-1:-1] top_k_indices = top_k_indices[np.argsort(similarities[top_k_indices])][::-1] similar_models = [] for idx in top_k_indices: if idx == model_idx: continue row = df.iloc[idx] similar_models.append({ "model_id": row.get('model_id', 'Unknown'), "similarity": float(similarities[idx]), "distance": float(1 - similarities[idx]), "library_name": row.get('library_name'), "pipeline_tag": row.get('pipeline_tag'), "downloads": int(row.get('downloads', 0)), "likes": int(row.get('likes', 0)), }) return { "query_model": model_id, "similar_models": similar_models } @app.get("/api/models/semantic-similarity") async def get_models_by_semantic_similarity( query_model_id: str = Query(...), k: int = Query(100, ge=1, le=1000), min_downloads: int = Query(0), min_likes: int = Query(0), projection_method: str = Query("umap") ): """ Get models sorted by semantic similarity to a query model. Returns models with their similarity scores and coordinates. Useful for exploring the embedding space around a specific model. """ if deps.df is None or deps.embeddings is None: raise HTTPException(status_code=503, detail="Data not loaded") df = deps.df embeddings = deps.embeddings # Find the query model if 'model_id' in df.index.names or df.index.name == 'model_id': try: model_row = df.loc[[query_model_id]] model_idx = model_row.index[0] except KeyError: raise HTTPException(status_code=404, detail="Query model not found") else: model_row = df[df.get('model_id', '') == query_model_id] if len(model_row) == 0: raise HTTPException(status_code=404, detail="Query model not found") model_idx = model_row.index[0] query_embedding = embeddings[model_idx] filtered_df = data_loader.filter_data( df=df, min_downloads=min_downloads, min_likes=min_likes, search_query=None, libraries=None, pipeline_tags=None ) if df.index.name == 'model_id' or 'model_id' in df.index.names: filtered_indices = [df.index.get_loc(idx) for idx in filtered_df.index] filtered_indices = np.array(filtered_indices, dtype=int) else: filtered_indices = filtered_df.index.values.astype(int) filtered_embeddings = embeddings[filtered_indices] from sklearn.metrics.pairwise import cosine_similarity query_embedding_2d = query_embedding.reshape(1, -1) similarities = cosine_similarity(query_embedding_2d, filtered_embeddings)[0] top_k_local_indices = np.argpartition(similarities, -k)[-k:] top_k_local_indices = top_k_local_indices[np.argsort(similarities[top_k_local_indices])][::-1] if reduced_embeddings is None: raise HTTPException(status_code=503, detail="Reduced embeddings not ready") top_k_original_indices = filtered_indices[top_k_local_indices] top_k_reduced = reduced_embeddings[top_k_original_indices] similar_models = [] for i, orig_idx in enumerate(top_k_original_indices): row = df.iloc[orig_idx] local_idx = top_k_local_indices[i] similar_models.append({ "model_id": str(row.get('model_id', 'Unknown')), "x": float(top_k_reduced[i, 0]), "y": float(top_k_reduced[i, 1]), "z": float(top_k_reduced[i, 2]) if top_k_reduced.shape[1] > 2 else 0.0, "similarity": float(similarities[local_idx]), "distance": float(1 - similarities[local_idx]), "library_name": str(row.get('library_name')) if pd.notna(row.get('library_name')) else None, "pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None, "downloads": int(row.get('downloads', 0)), "likes": int(row.get('likes', 0)), "trending_score": float(row.get('trendingScore')) if pd.notna(row.get('trendingScore')) else None, "tags": str(row.get('tags')) if pd.notna(row.get('tags')) else None, "parent_model": str(row.get('parent_model')) if pd.notna(row.get('parent_model')) else None, "licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None, }) return { "query_model": query_model_id, "models": similar_models, "count": len(similar_models) } @app.get("/api/distance") async def get_distance( model_id_1: str = Query(...), model_id_2: str = Query(...) ): """ Calculate distance/similarity between two models. """ if deps.df is None or deps.embeddings is None: raise HTTPException(status_code=503, detail="Data not loaded") df = deps.df embeddings = deps.embeddings # Find both models - optimized with index lookup if 'model_id' in df.index.names or df.index.name == 'model_id': try: model1_row = df.loc[[model_id_1]] model2_row = df.loc[[model_id_2]] idx1 = model1_row.index[0] idx2 = model2_row.index[0] except KeyError: raise HTTPException(status_code=404, detail="One or both models not found") else: model1_row = df[df.get('model_id', '') == model_id_1] model2_row = df[df.get('model_id', '') == model_id_2] if len(model1_row) == 0 or len(model2_row) == 0: raise HTTPException(status_code=404, detail="One or both models not found") idx1 = model1_row.index[0] idx2 = model2_row.index[0] from sklearn.metrics.pairwise import cosine_similarity similarity = cosine_similarity([embeddings[idx1]], [embeddings[idx2]])[0][0] distance = 1 - similarity return { "model_1": model_id_1, "model_2": model_id_2, "cosine_similarity": float(similarity), "cosine_distance": float(distance), "euclidean_distance": float(np.linalg.norm(embeddings[idx1] - embeddings[idx2])) } @app.post("/api/export") async def export_models(model_ids: List[str]): """ Export selected models as JSON with full metadata. """ if df is None: raise DataNotLoadedError() # Optimized export with index lookup if 'model_id' in df.index.names or df.index.name == 'model_id': try: exported = df.loc[model_ids] except KeyError: # Fallback if some IDs not in index exported = df[df.get('model_id', '').isin(model_ids)] else: exported = df[df.get('model_id', '').isin(model_ids)] if len(exported) == 0: return {"models": []} models = [ { "model_id": str(row.get('model_id', '')), "library_name": str(row.get('library_name')) if pd.notna(row.get('library_name')) else None, "pipeline_tag": str(row.get('pipeline_tag')) if pd.notna(row.get('pipeline_tag')) else None, "downloads": int(row.get('downloads', 0)) if pd.notna(row.get('downloads')) else 0, "likes": int(row.get('likes', 0)) if pd.notna(row.get('likes')) else 0, "trending_score": float(row.get('trendingScore', 0)) if pd.notna(row.get('trendingScore')) else None, "tags": str(row.get('tags')) if pd.notna(row.get('tags')) else None, "licenses": str(row.get('licenses')) if pd.notna(row.get('licenses')) else None, "parent_model": str(row.get('parent_model')) if pd.notna(row.get('parent_model')) else None, } for _, row in exported.iterrows() ] return { "count": len(models), "models": models } @app.get("/api/network/cooccurrence") async def get_cooccurrence_network( library: Optional[str] = Query(None), pipeline_tag: Optional[str] = Query(None), min_downloads: int = Query(0), min_likes: int = Query(0), n: int = Query(100, ge=1, le=1000), cooccurrence_method: str = Query("combined", regex="^(parent_family|library|pipeline|tags|combined)$") ): """ Build co-occurrence network for models (inspired by Open Syllabus Project). Connects models that appear together in same contexts (parent family, library, pipeline, tags). Returns network graph data suitable for visualization. """ if df is None: raise DataNotLoadedError() try: network_builder = ModelNetworkBuilder(df) top_models = network_builder.get_top_models_by_field( library=library, pipeline_tag=pipeline_tag, min_downloads=min_downloads, min_likes=min_likes, n=n ) if not top_models: return { "nodes": [], "links": [], "statistics": {} } model_ids = [mid for mid, _ in top_models] graph = network_builder.build_cooccurrence_network( model_ids=model_ids, cooccurrence_method=cooccurrence_method ) nodes = [] for node_id, attrs in graph.nodes(data=True): nodes.append({ "id": node_id, "title": attrs.get('title', node_id), "author": attrs.get('author', ''), "freq": attrs.get('freq', 0), "likes": attrs.get('likes', 0), "library": attrs.get('library', ''), "pipeline": attrs.get('pipeline', '') }) links = [] for source, target, attrs in graph.edges(data=True): links.append({ "source": source, "target": target, "weight": attrs.get('weight', 1) }) stats = network_builder.get_network_statistics(graph) return { "nodes": nodes, "links": links, "statistics": stats } except (ValueError, KeyError, AttributeError) as e: logger.error(f"Error building network: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error building network: {str(e)}") @app.get("/api/network/family/{model_id}") async def get_family_network( model_id: str, max_depth: Optional[int] = Query(None, ge=1, le=100, description="Maximum depth to traverse. If None, traverses entire tree without limit."), edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."), include_edge_attributes: bool = Query(True, description="Whether to include edge attributes (change in likes, downloads, etc.)") ): """ Build family tree network for a model (directed graph). Returns network graph data showing parent-child relationships with multiple relationship types. Supports filtering by edge type (finetune, quantized, adapter, merge, parent). """ if df is None: raise DataNotLoadedError() try: filter_types = None if edge_types: filter_types = [t.strip() for t in edge_types.split(',') if t.strip()] network_builder = ModelNetworkBuilder(df) graph = network_builder.build_family_tree_network( root_model_id=model_id, max_depth=max_depth, include_edge_attributes=include_edge_attributes, filter_edge_types=filter_types ) nodes = [] for node_id, attrs in graph.nodes(data=True): nodes.append({ "id": node_id, "title": attrs.get('title', node_id), "freq": attrs.get('freq', 0), "likes": attrs.get('likes', 0), "downloads": attrs.get('downloads', 0), "library": attrs.get('library', ''), "pipeline": attrs.get('pipeline', '') }) links = [] for source, target, edge_attrs in graph.edges(data=True): link_data = { "source": source, "target": target, "edge_type": edge_attrs.get('edge_type'), "edge_types": edge_attrs.get('edge_types', []) } if include_edge_attributes: link_data.update({ "change_in_likes": edge_attrs.get('change_in_likes'), "percentage_change_in_likes": edge_attrs.get('percentage_change_in_likes'), "change_in_downloads": edge_attrs.get('change_in_downloads'), "percentage_change_in_downloads": edge_attrs.get('percentage_change_in_downloads'), "change_in_createdAt_days": edge_attrs.get('change_in_createdAt_days') }) links.append(link_data) stats = network_builder.get_network_statistics(graph) return { "nodes": nodes, "links": links, "statistics": stats, "root_model": model_id } except (ValueError, KeyError, AttributeError) as e: logger.error(f"Error building family network: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error building family network: {str(e)}") @app.get("/api/network/full-derivatives") @cached_response(ttl=3600, key_prefix="full_derivatives_network") async def get_full_derivative_network( edge_types: Optional[str] = Query(None, description="Comma-separated list of edge types to include (finetune,quantized,adapter,merge,parent). If None, includes all types."), include_edge_attributes: bool = Query(False, description="Whether to include edge attributes (change in likes, downloads, etc.). Default False for performance."), include_positions: bool = Query(True, description="Whether to include pre-computed 3D positions for each node. Default True for faster rendering."), min_downloads: int = Query(0, description="Minimum downloads to include a model. Use this to reduce network size."), max_nodes: Optional[int] = Query(None, ge=100, le=1000000, description="Maximum number of nodes to include. Models are sorted by downloads. Use this to reduce network size."), use_precomputed: bool = Query(True, description="Try to load pre-computed network graph from disk if available.") ): """ Build full derivative relationship network for ALL models in the database. Returns a non-embedding based force-directed graph where edges represent derivative types. This computes over every single model in the database. Note: Edge attributes are disabled by default for performance with large datasets. If pre-computed positions exist, they will be included in the response. """ if deps.df is None or deps.df.empty: raise HTTPException( status_code=503, detail="Model data not loaded. Please wait for the server to finish loading data." ) try: import time import networkx as nx start_time = time.time() # Check if dataframe has required columns required_columns = ['model_id'] missing_columns = [col for col in required_columns if col not in deps.df.columns] if missing_columns: raise HTTPException( status_code=500, detail=f"Missing required columns: {missing_columns}" ) # Try to load pre-computed network graph graph = None if use_precomputed: try: backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) precomputed_dir = os.path.join(root_dir, "precomputed_data") graph_file = os.path.join(precomputed_dir, "full_derivative_network.pkl") # Try to download from HF Hub if not found locally (for Spaces deployment) if not os.path.exists(graph_file): logger.info("Pre-computed network not found locally. Attempting to download from HF Hub...") from utils.precomputed_loader import download_network_from_hf_hub download_network_from_hf_hub(precomputed_dir, version="v1") if os.path.exists(graph_file): logger.info(f"Loading pre-computed network graph from {graph_file}...") with open(graph_file, 'rb') as f: graph = pickle.load(f) logger.info(f"Loaded pre-computed graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges") else: logger.info("Pre-computed network graph not available. Will build from scratch.") except Exception as e: logger.warning(f"Could not load pre-computed network graph: {e}. Will build from scratch.") # Filter dataframe if needed filtered_df = deps.df.copy() if min_downloads > 0: filtered_df = filtered_df[filtered_df.get('downloads', 0) >= min_downloads] logger.info(f"Filtered to {len(filtered_df):,} models with >= {min_downloads} downloads") if max_nodes and len(filtered_df) > max_nodes: # Sort by downloads and take top N filtered_df = filtered_df.nlargest(max_nodes, 'downloads', keep='first') logger.info(f"Limited to top {max_nodes:,} models by downloads") logger.info(f"Building full derivative network for {len(filtered_df):,} models...") filter_types = None if edge_types: filter_types = [t.strip() for t in edge_types.split(',') if t.strip()] # Build graph if not loaded from disk if graph is None: try: network_builder = ModelNetworkBuilder(filtered_df) logger.info("Calling build_full_derivative_network...") # Disable edge attributes for very large graphs to improve performance # They can be slow to compute for 100k+ edges graph = network_builder.build_full_derivative_network( include_edge_attributes=include_edge_attributes, filter_edge_types=filter_types ) except Exception as build_error: logger.error(f"Error in build_full_derivative_network: {build_error}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to build network graph: {str(build_error)}" ) else: # Filter pre-computed graph if needed if filter_types: # Remove edges that don't match filter edges_to_remove = [] for source, target, attrs in graph.edges(data=True): edge_types_list = attrs.get('edge_types', []) if not isinstance(edge_types_list, list): edge_types_list = [edge_types_list] if edge_types_list else [] if not any(et in filter_types for et in edge_types_list): edges_to_remove.append((source, target)) graph.remove_edges_from(edges_to_remove) # Remove isolated nodes isolated = list(nx.isolates(graph)) graph.remove_nodes_from(isolated) logger.info(f"Filtered graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges") # Filter nodes by downloads if needed if min_downloads > 0 or max_nodes: nodes_to_remove = [] for node_id in graph.nodes(): if node_id in filtered_df.index: continue nodes_to_remove.append(node_id) graph.remove_nodes_from(nodes_to_remove) isolated = list(nx.isolates(graph)) graph.remove_nodes_from(isolated) logger.info(f"Filtered graph by model selection: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges") build_time = time.time() - start_time logger.info(f"Graph built in {build_time:.2f}s: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges") # Load pre-computed positions if available precomputed_positions = {} if include_positions: try: backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) layout_file = os.path.join(root_dir, "precomputed_data", "force_layout_3d.pkl") if os.path.exists(layout_file): with open(layout_file, 'rb') as f: layout_data = pickle.load(f) precomputed_positions = layout_data.get('positions', {}) logger.info(f"Loaded {len(precomputed_positions):,} pre-computed positions") except Exception as e: logger.warning(f"Could not load pre-computed positions: {e}") # Build nodes list with optional pre-computed positions nodes = [] for node_id, attrs in graph.nodes(data=True): node_data = { "id": node_id, "title": attrs.get('title', node_id), "freq": attrs.get('freq', 0), "likes": attrs.get('likes', 0), "downloads": attrs.get('downloads', 0), "library": attrs.get('library', ''), "pipeline": attrs.get('pipeline', '') } # Add pre-computed position if available if node_id in precomputed_positions: pos = precomputed_positions[node_id] node_data['x'] = pos[0] node_data['y'] = pos[1] node_data['z'] = pos[2] nodes.append(node_data) logger.info(f"Processed {len(nodes):,} nodes") # Build links list links = [] edge_count = 0 for source, target, edge_attrs in graph.edges(data=True): link_data = { "source": source, "target": target, "edge_type": edge_attrs.get('edge_type'), "edge_types": edge_attrs.get('edge_types', []) } if include_edge_attributes: link_data.update({ "change_in_likes": edge_attrs.get('change_in_likes'), "percentage_change_in_likes": edge_attrs.get('percentage_change_in_likes'), "change_in_downloads": edge_attrs.get('change_in_downloads'), "percentage_change_in_downloads": edge_attrs.get('percentage_change_in_downloads'), "change_in_createdAt_days": edge_attrs.get('change_in_createdAt_days') }) links.append(link_data) edge_count += 1 if edge_count % 10000 == 0: logger.info(f"Processed {edge_count:,} edges...") logger.info(f"Processed {len(links):,} links") try: stats = network_builder.get_network_statistics(graph) except Exception as stats_error: logger.warning(f"Could not calculate network statistics: {stats_error}") stats = { "nodes": len(nodes), "edges": len(links), "density": 0.0, "avg_degree": 0.0, "clustering": 0.0 } total_time = time.time() - start_time logger.info(f"Full derivative network built successfully in {total_time:.2f}s") return { "nodes": nodes, "links": links, "statistics": stats } except HTTPException: # Re-raise HTTP exceptions as-is raise except DataNotLoadedError: raise HTTPException( status_code=503, detail="Model data not loaded. Please wait for the server to finish loading data." ) except Exception as e: import traceback error_trace = traceback.format_exc() logger.error(f"Error building full derivative network: {e}\n{error_trace}") error_detail = f"Error building full derivative network: {str(e)}" if isinstance(e, (ValueError, KeyError, AttributeError)): error_detail += f" (Type: {type(e).__name__})" # Provide more helpful error message if "memory" in str(e).lower() or "MemoryError" in str(type(e)): error_detail += ". The dataset may be too large. Try filtering by edge types." raise HTTPException(status_code=500, detail=error_detail) @app.get("/api/search/neighbors/{model_id}") async def get_model_neighbors( model_id: str, max_neighbors: int = Query(50, ge=1, le=200), min_weight: float = Query(0.0, ge=0.0) ): """ Find neighbors of a model in the co-occurrence network (graph-based search). Similar to graph database queries for finding connected nodes. """ if df is None: raise DataNotLoadedError() try: network_builder = ModelNetworkBuilder(df) top_models = network_builder.get_top_models_by_field(n=1000) model_ids = [mid for mid, _ in top_models] graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined') neighbors = network_builder.find_neighbors( model_id=model_id, graph=graph, max_neighbors=max_neighbors, min_weight=min_weight ) return { "model_id": model_id, "neighbors": neighbors, "count": len(neighbors) } except (ValueError, KeyError, AttributeError) as e: logger.error(f"Error finding neighbors: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error finding neighbors: {str(e)}") @app.get("/api/search/path") async def find_path_between_models( source_id: str = Query(...), target_id: str = Query(...), max_path_length: int = Query(5, ge=1, le=10) ): """ Find shortest path between two models (graph-based search). Similar to graph database path queries. """ if df is None: raise DataNotLoadedError() try: network_builder = ModelNetworkBuilder(df) # Build network for top models (for performance) top_models = network_builder.get_top_models_by_field(n=1000) model_ids = [mid for mid, _ in top_models] graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined') path = network_builder.find_path( source_id=source_id, target_id=target_id, graph=graph, max_path_length=max_path_length ) if path is None: return { "source_id": source_id, "target_id": target_id, "path": None, "path_length": None, "found": False } return { "source_id": source_id, "target_id": target_id, "path": path, "path_length": len(path) - 1, "found": True } except Exception as e: raise HTTPException(status_code=500, detail=f"Error finding path: {str(e)}") @app.get("/api/search/cooccurrence/{model_id}") async def search_by_cooccurrence( model_id: str, max_results: int = Query(20, ge=1, le=100), min_weight: float = Query(1.0, ge=0.0) ): """ Search for models that co-occur with a query model. Similar to graph database queries for co-assignment patterns. """ if df is None: raise DataNotLoadedError() try: network_builder = ModelNetworkBuilder(df) # Build network for top models (for performance) top_models = network_builder.get_top_models_by_field(n=1000) model_ids = [mid for mid, _ in top_models] graph = network_builder.build_cooccurrence_network(model_ids, cooccurrence_method='combined') results = network_builder.search_by_cooccurrence( query_model_id=model_id, graph=graph, max_results=max_results, min_weight=min_weight ) return { "query_model": model_id, "cooccurring_models": results, "count": len(results) } except Exception as e: raise HTTPException(status_code=500, detail=f"Error searching by co-occurrence: {str(e)}") @app.get("/api/search/relationships/{model_id}") async def get_model_relationships( model_id: str, relationship_type: str = Query("all", regex="^(family|library|pipeline|tags|all)$"), max_results: int = Query(50, ge=1, le=200) ): """ Find models by specific relationship types (family, library, pipeline, tags). Similar to graph database relationship queries. """ if df is None: raise DataNotLoadedError() try: network_builder = ModelNetworkBuilder(df) related_models = network_builder.find_models_by_relationship( model_id=model_id, relationship_type=relationship_type, max_results=max_results ) return { "model_id": model_id, "relationship_type": relationship_type, "related_models": related_models, "count": len(related_models) } except Exception as e: raise HTTPException(status_code=500, detail=f"Error finding relationships: {str(e)}") @app.get("/api/model-count/current") async def get_current_model_count( use_cache: bool = Query(True), force_refresh: bool = Query(False), use_dataset_snapshot: bool = Query(False), use_models_page: bool = Query(True) ): """ Get the current number of models on Hugging Face Hub. Uses multiple strategies: models page scraping (fastest), dataset snapshot, or API. Query Parameters: use_cache: Use cached results if available (default: True) force_refresh: Force refresh even if cache is valid (default: False) use_dataset_snapshot: Use dataset snapshot for breakdowns (default: False) use_models_page: Try to get count from HF models page first (default: True) """ try: tracker = get_tracker() if use_dataset_snapshot: count_data = tracker.get_count_from_models_page() if count_data is None: count_data = tracker.get_current_model_count(use_models_page=False) else: try: from utils.data_loader import ModelDataLoader data_loader = ModelDataLoader() df = data_loader.load_data(sample_size=10000, prioritize_base_models=True) library_counts = {} pipeline_counts = {} for _, row in df.iterrows(): if pd.notna(row.get('library_name')): lib = str(row.get('library_name')) library_counts[lib] = library_counts.get(lib, 0) + 1 if pd.notna(row.get('pipeline_tag')): pipeline = str(row.get('pipeline_tag')) pipeline_counts[pipeline] = pipeline_counts.get(pipeline, 0) + 1 if len(df) > 0 and count_data["total_models"] > len(df): scale_factor = count_data["total_models"] / len(df) library_counts = {k: int(v * scale_factor) for k, v in library_counts.items()} pipeline_counts = {k: int(v * scale_factor) for k, v in pipeline_counts.items()} count_data["models_by_library"] = library_counts count_data["models_by_pipeline"] = pipeline_counts except Exception as e: logger.warning(f"Could not get breakdowns from dataset: {e}") else: count_data = tracker.get_current_model_count(use_models_page=use_models_page) return count_data except Exception as e: logger.error(f"Error fetching model count: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error fetching model count: {str(e)}") @app.get("/api/model-count/historical") async def get_historical_model_counts( days: int = Query(30, ge=1, le=365), start_date: Optional[str] = Query(None), end_date: Optional[str] = Query(None) ): """ Get historical model counts. Args: days: Number of days to look back (if start_date not provided) start_date: Start date in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS) end_date: End date in ISO format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS) """ try: from datetime import datetime tracker = get_tracker() start = None end = None if start_date: start = datetime.fromisoformat(start_date.replace('Z', '+00:00')) if end_date: end = datetime.fromisoformat(end_date.replace('Z', '+00:00')) if start is None: from datetime import timedelta start = datetime.utcnow() - timedelta(days=days) historical = tracker.get_historical_counts(start, end) return { "counts": historical, "count": len(historical), "start_date": start.isoformat() if start else None, "end_date": end.isoformat() if end else None } except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching historical counts: {str(e)}") @app.get("/api/model-count/latest") async def get_latest_model_count(): """Get the most recently recorded model count from database.""" try: tracker = get_tracker() latest = tracker.get_latest_count() if latest is None: raise HTTPException(status_code=404, detail="No model counts recorded yet") return latest except HTTPException: raise except Exception as e: raise HTTPException(status_code=500, detail=f"Error fetching latest count: {str(e)}") @app.post("/api/model-count/record") async def record_model_count( background_tasks: BackgroundTasks, use_dataset_snapshot: bool = Query(False, description="Use dataset snapshot instead of API (faster)") ): """ Record the current model count to the database. This can be called periodically (e.g., via cron job) to track growth over time. Query Parameters: use_dataset_snapshot: Use dataset snapshot instead of API (faster, default: False) """ try: tracker = get_tracker() def record(): if use_dataset_snapshot: count_data = tracker.get_count_from_dataset_snapshot() if count_data: tracker.record_count(count_data, source="dataset_snapshot") else: count_data = tracker.get_current_model_count(use_cache=False) tracker.record_count(count_data, source="api") else: count_data = tracker.get_current_model_count(use_cache=False) tracker.record_count(count_data, source="api") background_tasks.add_task(record) return { "status": "recording", "message": "Model count recording started in background", "source": "dataset_snapshot" if use_dataset_snapshot else "api" } except Exception as e: raise HTTPException(status_code=500, detail=f"Error recording model count: {str(e)}") @app.get("/api/model-count/growth") async def get_growth_stats(days: int = Query(7, ge=1, le=365)): """ Get growth statistics over the specified period. Args: days: Number of days to analyze """ try: tracker = get_tracker() stats = tracker.get_growth_stats(days) return stats except Exception as e: raise HTTPException(status_code=500, detail=f"Error calculating growth stats: {str(e)}") @app.get("/api/network/export/graphml") async def export_network_graphml( background_tasks: BackgroundTasks, library: Optional[str] = Query(None), pipeline_tag: Optional[str] = Query(None), min_downloads: int = Query(0), min_likes: int = Query(0), n: int = Query(100, ge=1, le=1000), cooccurrence_method: str = Query("combined", regex="^(parent_family|library|pipeline|tags|combined)$") ): """ Export co-occurrence network as GraphML file (for import into Gephi, Cytoscape, etc.). Similar to Open Syllabus graph export functionality. """ if df is None: raise DataNotLoadedError() try: network_builder = ModelNetworkBuilder(df) top_models = network_builder.get_top_models_by_field( library=library, pipeline_tag=pipeline_tag, min_downloads=min_downloads, min_likes=min_likes, n=n ) if not top_models: raise HTTPException(status_code=404, detail="No models found matching criteria") model_ids = [mid for mid, _ in top_models] graph = network_builder.build_cooccurrence_network( model_ids=model_ids, cooccurrence_method=cooccurrence_method ) with tempfile.NamedTemporaryFile(mode='w', suffix='.graphml', delete=False) as tmp_file: tmp_path = tmp_file.name network_builder.export_graphml(graph, tmp_path) background_tasks.add_task(os.unlink, tmp_path) return FileResponse( tmp_path, media_type='application/xml', filename=f'network_{cooccurrence_method}_{n}_models.graphml' ) except (ValueError, KeyError, AttributeError, IOError) as e: logger.error(f"Error exporting network: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Error exporting network: {str(e)}") @app.get("/api/model/{model_id}/papers") async def get_model_papers(model_id: str): """ Get arXiv papers associated with a model. Extracts arXiv IDs from model tags and fetches paper information. """ if df is None: raise DataNotLoadedError() model = df[df.get('model_id', '') == model_id] if len(model) == 0: raise HTTPException(status_code=404, detail="Model not found") model = model.iloc[0] # Extract arXiv IDs from tags tags_str = str(model.get('tags', '')) if pd.notna(model.get('tags')) else '' arxiv_ids = extract_arxiv_ids(tags_str) if not arxiv_ids: return { "model_id": model_id, "arxiv_ids": [], "papers": [] } # Fetch papers papers = await fetch_arxiv_papers(arxiv_ids[:10]) # Limit to 10 papers return { "model_id": model_id, "arxiv_ids": arxiv_ids, "papers": papers } @app.get("/api/models/minimal.bin") async def get_minimal_binary(): """ Serve the binary minimal dataset file. This is optimized for fast client-side loading. """ backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) binary_path = os.path.join(root_dir, "cache", "binary", "embeddings.bin") if not os.path.exists(binary_path): raise HTTPException(status_code=404, detail="Binary dataset not found. Run export_binary.py first.") return FileResponse( binary_path, media_type="application/octet-stream", headers={ "Content-Disposition": "attachment; filename=embeddings.bin", "Cache-Control": "public, max-age=3600" } ) @app.get("/api/models/model_ids.json") async def get_model_ids_json(): """Serve the model IDs JSON file.""" backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) json_path = os.path.join(root_dir, "cache", "binary", "model_ids.json") if not os.path.exists(json_path): raise HTTPException(status_code=404, detail="Model IDs file not found.") return FileResponse( json_path, media_type="application/json", headers={"Cache-Control": "public, max-age=3600"} ) @app.get("/api/models/metadata.json") async def get_metadata_json(): """Serve the metadata JSON file with lookup tables.""" backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) root_dir = os.path.dirname(backend_dir) json_path = os.path.join(root_dir, "cache", "binary", "metadata.json") if not os.path.exists(json_path): raise HTTPException(status_code=404, detail="Metadata file not found.") return FileResponse( json_path, media_type="application/json", headers={"Cache-Control": "public, max-age=3600"} ) @app.get("/api/model/{model_id}/files") async def get_model_files(model_id: str, branch: str = Query("main")): """ Get file tree for a model from Hugging Face. Proxies the request to avoid CORS issues. Returns a flat list of files with path and size information. """ if not model_id or not model_id.strip(): raise HTTPException(status_code=400, detail="Invalid model ID") branches_to_try = [branch, "main", "master"] if branch not in ["main", "master"] else [branch, "main" if branch == "master" else "master"] try: async with httpx.AsyncClient(timeout=15.0) as client: for branch_name in branches_to_try: try: url = f"https://huggingface.co/api/models/{model_id}/tree/{branch_name}" response = await client.get(url) if response.status_code == 200: data = response.json() # Ensure we return an array if isinstance(data, list): return data elif isinstance(data, dict) and 'tree' in data: return data['tree'] else: return [] elif response.status_code == 404: # Try next branch continue else: logger.warning(f"Unexpected status {response.status_code} for {url}") continue except httpx.HTTPStatusError as e: if e.response.status_code == 404: continue # Try next branch logger.warning(f"HTTP error for branch {branch_name}: {e}") continue except httpx.HTTPError as e: logger.warning(f"HTTP error for branch {branch_name}: {e}") continue # All branches failed raise HTTPException( status_code=404, detail=f"File tree not found for model '{model_id}'. The model may not exist or may not have any files." ) except httpx.TimeoutException: raise HTTPException( status_code=504, detail="Request to Hugging Face timed out. Please try again later." ) except HTTPException: raise # Re-raise HTTP exceptions except Exception as e: logger.error(f"Error fetching file tree: {e}", exc_info=True) raise HTTPException( status_code=500, detail=f"Error fetching file tree: {str(e)}" ) # ============================================================================= # BACKGROUND COMPUTATION ENDPOINTS # ============================================================================= import subprocess import threading # Store for background process _background_process = None _background_lock = threading.Lock() class ComputeRequest(BaseModel): sample_size: Optional[int] = None all_models: bool = False @app.get("/api/compute/status") async def get_compute_status(): """Get the status of background pre-computation.""" from pathlib import Path root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) status_file = Path(root_dir) / "precomputed_data" / "background_status_v1.json" if status_file.exists(): import json with open(status_file, 'r') as f: status = json.load(f) # Check if process is still running global _background_process with _background_lock: if _background_process is not None: poll = _background_process.poll() if poll is None: status['process_running'] = True else: status['process_running'] = False status['process_exit_code'] = poll else: status['process_running'] = False return status # Check for existing precomputed data metadata_file = Path(root_dir) / "precomputed_data" / "metadata_v1.json" models_file = Path(root_dir) / "precomputed_data" / "models_v1.parquet" if metadata_file.exists() and models_file.exists(): import json with open(metadata_file, 'r') as f: metadata = json.load(f) return { 'status': 'completed', 'total_models': metadata.get('total_models', 0), 'created_at': metadata.get('created_at'), 'process_running': False } return { 'status': 'not_started', 'total_models': 0, 'process_running': False } @app.post("/api/compute/start") async def start_background_compute(request: ComputeRequest, background_tasks: BackgroundTasks): """Start background pre-computation of model embeddings.""" global _background_process with _background_lock: if _background_process is not None and _background_process.poll() is None: raise HTTPException( status_code=409, detail="Background computation is already running" ) # Prepare command root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) script_path = os.path.join(root_dir, "backend", "scripts", "precompute_background.py") venv_python = os.path.join(root_dir, "venv", "bin", "python") cmd = [venv_python, script_path] if request.all_models: cmd.append("--all") elif request.sample_size: cmd.extend(["--sample-size", str(request.sample_size)]) else: cmd.extend(["--sample-size", "150000"]) # Default cmd.extend(["--output-dir", os.path.join(root_dir, "precomputed_data")]) # Start process in background log_file = os.path.join(root_dir, "precompute_background.log") def run_computation(): global _background_process with open(log_file, 'w') as f: with _background_lock: _background_process = subprocess.Popen( cmd, stdout=f, stderr=subprocess.STDOUT, cwd=os.path.join(root_dir, "backend") ) _background_process.wait() thread = threading.Thread(target=run_computation, daemon=True) thread.start() sample_desc = "all models" if request.all_models else f"{request.sample_size or 150000:,} models" return { "message": f"Background computation started for {sample_desc}", "status": "starting", "log_file": log_file } @app.post("/api/compute/stop") async def stop_background_compute(): """Stop the running background computation.""" global _background_process with _background_lock: if _background_process is None or _background_process.poll() is not None: return {"message": "No computation is running"} _background_process.terminate() try: _background_process.wait(timeout=5) except subprocess.TimeoutExpired: _background_process.kill() return {"message": "Background computation stopped"} @app.get("/api/data/info") async def get_data_info(): """Get information about currently loaded data.""" df = deps.df if df is None: return { "loaded": False, "message": "No data loaded" } return { "loaded": True, "total_models": len(df), "columns": list(df.columns), "unique_libraries": int(df['library_name'].nunique()) if 'library_name' in df.columns else 0, "unique_pipelines": int(df['pipeline_tag'].nunique()) if 'pipeline_tag' in df.columns else 0, "has_3d_coords": all(col in df.columns for col in ['x_3d', 'y_3d', 'z_3d']), "has_2d_coords": all(col in df.columns for col in ['x_2d', 'y_2d']) } # ============================================================================= # STATIC FILE SERVING (for HF Spaces full-stack deployment) # ============================================================================= from fastapi.staticfiles import StaticFiles from starlette.responses import FileResponse as StarletteFileResponse # Check if frontend build exists (for HF Spaces deployment) frontend_build_path = os.path.join(os.path.dirname(backend_dir), "frontend", "build") if os.path.exists(frontend_build_path): # Serve static files from React build app.mount("/static", StaticFiles(directory=os.path.join(frontend_build_path, "static")), name="static") @app.get("/{full_path:path}") async def serve_frontend(full_path: str): """Serve React frontend for non-API routes.""" # Don't serve frontend for API routes if full_path.startswith("api/") or full_path == "docs" or full_path == "openapi.json": raise HTTPException(status_code=404, detail="Not found") # Try to serve the requested file file_path = os.path.join(frontend_build_path, full_path) if os.path.isfile(file_path): return StarletteFileResponse(file_path) # Fall back to index.html for SPA routing index_path = os.path.join(frontend_build_path, "index.html") if os.path.exists(index_path): return StarletteFileResponse(index_path) raise HTTPException(status_code=404, detail="Not found") if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", 8000)) uvicorn.run(app, host="0.0.0.0", port=port)