# hf-viz/backend/scripts/precompute_network.py
# (Provenance: commit 3e85304 — "Add network pre-computation, styling
# improvements, and theme toggle")
"""
Pre-compute the full derivative network graph and save it to disk.
This allows the API to load the network instantly instead of building it on-demand.
Usage:
python scripts/precompute_network.py [--output-dir precomputed_data] [--version v1]
"""
import os
import sys
import pickle
import argparse
import logging
import time
from pathlib import Path
from typing import Optional
# Add backend to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pandas as pd
from utils.network_analysis import ModelNetworkBuilder
from utils.precomputed_loader import PrecomputedDataLoader
from utils.data_loader import ModelDataLoader
# Configure root logging once at import time so this script and the imported
# utils modules all emit timestamped INFO-level messages in one format.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def precompute_network(
    output_dir: str = "precomputed_data",
    version: str = "v1",
    include_edge_attributes: bool = False,
    min_downloads: int = 0,
    max_nodes: Optional[int] = None,
    load_from_hf: bool = False,
    sample_size: Optional[int] = None
) -> bool:
    """
    Pre-compute the full derivative network graph for the force-directed visualization.

    Args:
        output_dir: Directory to save the network file
        version: Version tag for the data
        include_edge_attributes: Whether to calculate edge attributes
        min_downloads: Minimum downloads to include a model
        max_nodes: Maximum number of nodes (top N by downloads)
        load_from_hf: If True, load directly from HF dataset (includes parent relationships)
        sample_size: If load_from_hf=True, sample this many models (None = all models)

    Returns:
        True on success; False if loading, building, or saving fails.
    """
    # Local imports kept function-scoped (as in the original) since they are
    # only needed for the metadata sidecar; hoisted to the top of the function
    # so all imports are visible in one place.
    import json
    from datetime import datetime

    start_time = time.time()

    # Create output directory (also the input directory for pre-computed data)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info("=" * 60)
    logger.info("PRE-COMPUTING FULL DERIVATIVE NETWORK")
    logger.info("=" * 60)

    # Step 1: Load model data
    logger.info("Step 1/3: Loading model data...")
    if load_from_hf:
        # Load directly from HF dataset (includes parent relationships)
        logger.info(f"Loading directly from Hugging Face dataset (sample_size={sample_size if sample_size else 'ALL'})...")
        data_loader = ModelDataLoader()
        df = data_loader.load_data(sample_size=sample_size, prioritize_base_models=False)
        df = data_loader.preprocess_for_embedding(df)
        # Ensure model_id is set as index (drop=False keeps it as a column too,
        # since downstream code may read either)
        if 'model_id' in df.columns:
            df.set_index('model_id', drop=False, inplace=True)
        # Ensure numeric columns; non-numeric entries become 0
        for col in ['downloads', 'likes']:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)
        logger.info(f"Loaded {len(df):,} models from HF dataset")
        logger.info(f"Columns: {list(df.columns)}")
        # Check if parent_model column exists (needed for network edges)
        if 'parent_model' not in df.columns:
            logger.warning("'parent_model' column not found - network will have 0 edges!")
        else:
            parent_count = df['parent_model'].notna().sum()
            logger.info(f"Models with parent relationships: {parent_count:,}")
    else:
        # Load from pre-computed files
        loader = PrecomputedDataLoader(data_dir=output_dir, version=version)
        if not loader.check_available():
            logger.error(f"Pre-computed data not found in {output_dir}")
            logger.info("Please run precompute_data.py first, download from HF Hub, or use --load-from-hf flag")
            return False
        try:
            # embeddings/metadata are loaded as part of load_all() but only df
            # is needed for the derivative network
            df, embeddings, metadata = loader.load_all()
            logger.info(f"Loaded {len(df):,} models from pre-computed data")
            # Check if parent_model column exists
            if 'parent_model' not in df.columns:
                logger.warning("'parent_model' column not found in pre-computed data - network will have 0 edges!")
        except Exception as e:
            logger.error(f"Failed to load data: {e}")
            return False

    # Step 2: Filter data if needed.
    # NOTE: the original used `df[df.get('downloads', 0) >= min_downloads]`,
    # which raises KeyError when the 'downloads' column is missing
    # (DataFrame.get returns the scalar default, making the mask a plain bool).
    # Guard on the column explicitly instead.
    has_downloads = 'downloads' in df.columns
    if min_downloads > 0:
        if has_downloads:
            df = df[df['downloads'] >= min_downloads]
            logger.info(f"Filtered to {len(df):,} models with >= {min_downloads} downloads")
        else:
            logger.warning("'downloads' column not found - skipping min-downloads filter")
    if max_nodes and len(df) > max_nodes:
        if has_downloads:
            df = df.nlargest(max_nodes, 'downloads', keep='first')
            logger.info(f"Limited to top {max_nodes:,} models by downloads")
        else:
            logger.warning("'downloads' column not found - skipping max-nodes limit")

    # Step 3: Build network graph
    logger.info("Step 2/3: Building network graph (this may take 10-30 minutes)...")
    try:
        network_builder = ModelNetworkBuilder(df)
        graph = network_builder.build_full_derivative_network(
            include_edge_attributes=include_edge_attributes,
            filter_edge_types=None  # Include all edge types
        )
        logger.info(f"Graph built: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
    except Exception as e:
        logger.error(f"Failed to build network graph: {e}", exc_info=True)
        return False

    # Step 4: Save network graph
    logger.info("Step 3/3: Saving network graph to disk...")
    network_file = output_path / "full_derivative_network.pkl"
    try:
        with open(network_file, 'wb') as f:
            pickle.dump(graph, f, protocol=pickle.HIGHEST_PROTOCOL)
        file_size_mb = network_file.stat().st_size / (1024 * 1024)
        logger.info(f"Saved network graph to {network_file}")
        logger.info(f"File size: {file_size_mb:.2f} MB")
    except Exception as e:
        logger.error(f"Failed to save network graph: {e}", exc_info=True)
        return False

    # Save a small JSON sidecar describing how this network was built
    metadata_file = output_path / "network_metadata.json"
    network_metadata = {
        "created_at": datetime.now().isoformat(),
        "version": version,
        "nodes": graph.number_of_nodes(),
        "edges": graph.number_of_edges(),
        "include_edge_attributes": include_edge_attributes,
        "min_downloads": min_downloads,
        "max_nodes": max_nodes,
        "file_size_mb": round(file_size_mb, 2)
    }
    with open(metadata_file, 'w') as f:
        json.dump(network_metadata, f, indent=2)

    total_time = time.time() - start_time
    logger.info("=" * 60)
    logger.info(f"PRE-COMPUTATION COMPLETE in {total_time:.2f} seconds")
    logger.info(f"Network graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
    logger.info(f"Saved to: {network_file}")
    logger.info("=" * 60)
    return True
if __name__ == "__main__":
    # NOTE: the original re-imported `time` here; it is already imported at
    # module level (and only used inside precompute_network), so the redundant
    # import has been removed.
    parser = argparse.ArgumentParser(description="Pre-compute full derivative network graph")
    parser.add_argument("--output-dir", type=str, default="precomputed_data",
                        help="Output directory for pre-computed files")
    parser.add_argument("--version", type=str, default="v1",
                        help="Version tag for the data")
    parser.add_argument("--include-edge-attributes", action="store_true",
                        help="Include edge attributes (slower but more detailed)")
    parser.add_argument("--min-downloads", type=int, default=0,
                        help="Minimum downloads to include a model")
    parser.add_argument("--max-nodes", type=int, default=None,
                        help="Maximum number of nodes (top N by downloads)")
    parser.add_argument("--load-from-hf", action="store_true",
                        help="Load directly from HF dataset instead of pre-computed files (includes parent relationships)")
    parser.add_argument("--sample-size", type=int, default=None,
                        help="If --load-from-hf, sample this many models (default: all models, use 0 for all)")
    args = parser.parse_args()

    # --sample-size 0 is an explicit "all models" sentinel, equivalent to
    # omitting the flag entirely (both map to None).
    success = precompute_network(
        output_dir=args.output_dir,
        version=args.version,
        include_edge_attributes=args.include_edge_attributes,
        min_downloads=args.min_downloads,
        max_nodes=args.max_nodes,
        load_from_hf=args.load_from_hf,
        sample_size=None if args.sample_size == 0 else args.sample_size
    )
    # Conventional exit codes: 0 = success, 1 = failure
    sys.exit(0 if success else 1)