"""
Pre-compute the full derivative network graph and save it to disk.
This allows the API to load the network instantly instead of building it on-demand.
Usage:
python scripts/precompute_network.py [--output-dir precomputed_data] [--version v1]
"""
import os
import sys
import pickle
import argparse
import logging
import time
from pathlib import Path
from typing import Optional
# Add backend to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pandas as pd
from utils.network_analysis import ModelNetworkBuilder
from utils.precomputed_loader import PrecomputedDataLoader
from utils.data_loader import ModelDataLoader
# Timestamped INFO-level logging so long-running pre-compute steps show progress.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
def _load_models_from_hf(sample_size: Optional[int]) -> pd.DataFrame:
    """Load models directly from the HF dataset (includes parent relationships)."""
    logger.info(f"Loading directly from Hugging Face dataset (sample_size={sample_size if sample_size else 'ALL'})...")
    data_loader = ModelDataLoader()
    df = data_loader.load_data(sample_size=sample_size, prioritize_base_models=False)
    df = data_loader.preprocess_for_embedding(df)
    # Ensure model_id is the index (kept as a column too for downstream lookups).
    if 'model_id' in df.columns:
        df.set_index('model_id', drop=False, inplace=True)
    # Coerce count columns to int; unparseable values become 0.
    for col in ['downloads', 'likes']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0).astype(int)
    logger.info(f"Loaded {len(df):,} models from HF dataset")
    logger.info(f"Columns: {list(df.columns)}")
    # parent_model is what the network edges are built from; warn loudly if absent.
    if 'parent_model' not in df.columns:
        logger.warning("'parent_model' column not found - network will have 0 edges!")
    else:
        parent_count = df['parent_model'].notna().sum()
        logger.info(f"Models with parent relationships: {parent_count:,}")
    return df


def _load_models_from_precomputed(output_dir: str, version: str) -> Optional[pd.DataFrame]:
    """Load models from pre-computed files; return None (after logging) on failure."""
    loader = PrecomputedDataLoader(data_dir=output_dir, version=version)
    if not loader.check_available():
        logger.error(f"Pre-computed data not found in {output_dir}")
        logger.info("Please run precompute_data.py first, download from HF Hub, or use --load-from-hf flag")
        return None
    try:
        df, embeddings, metadata = loader.load_all()
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        return None
    logger.info(f"Loaded {len(df):,} models from pre-computed data")
    if 'parent_model' not in df.columns:
        logger.warning("'parent_model' column not found in pre-computed data - network will have 0 edges!")
    return df


def _apply_filters(df: pd.DataFrame, min_downloads: int, max_nodes: Optional[int]) -> pd.DataFrame:
    """Optionally drop low-download models and cap the graph at the top-N by downloads."""
    # BUGFIX: the original compared against df.get('downloads', 0); DataFrame.get
    # returns the scalar default when the column is missing, which made the mask a
    # scalar and df[...] raise. Guard on column presence instead.
    if min_downloads > 0 and 'downloads' in df.columns:
        df = df[df['downloads'] >= min_downloads]
        logger.info(f"Filtered to {len(df):,} models with >= {min_downloads} downloads")
    if max_nodes and len(df) > max_nodes:
        df = df.nlargest(max_nodes, 'downloads', keep='first')
        logger.info(f"Limited to top {max_nodes:,} models by downloads")
    return df


def _save_graph(graph, network_file: Path) -> Optional[float]:
    """Pickle the graph to disk; return its size in MB, or None (after logging) on failure."""
    try:
        with open(network_file, 'wb') as f:
            pickle.dump(graph, f, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as e:
        logger.error(f"Failed to save network graph: {e}", exc_info=True)
        return None
    file_size_mb = network_file.stat().st_size / (1024 * 1024)
    logger.info(f"Saved network graph to {network_file}")
    logger.info(f"File size: {file_size_mb:.2f} MB")
    return file_size_mb


def _write_network_metadata(
    metadata_file: Path,
    graph,
    version: str,
    include_edge_attributes: bool,
    min_downloads: int,
    max_nodes: Optional[int],
    file_size_mb: float,
) -> None:
    """Write a small JSON sidecar describing the saved network."""
    # Local imports kept here (as in the original) since these are only needed
    # for the metadata sidecar.
    import json
    from datetime import datetime
    network_metadata = {
        "created_at": datetime.now().isoformat(),
        "version": version,
        "nodes": graph.number_of_nodes(),
        "edges": graph.number_of_edges(),
        "include_edge_attributes": include_edge_attributes,
        "min_downloads": min_downloads,
        "max_nodes": max_nodes,
        "file_size_mb": round(file_size_mb, 2)
    }
    with open(metadata_file, 'w') as f:
        json.dump(network_metadata, f, indent=2)


def precompute_network(
    output_dir: str = "precomputed_data",
    version: str = "v1",
    include_edge_attributes: bool = False,
    min_downloads: int = 0,
    max_nodes: Optional[int] = None,
    load_from_hf: bool = False,
    sample_size: Optional[int] = None
) -> bool:
    """
    Pre-compute the full derivative network graph for the force-directed visualization.

    Args:
        output_dir: Directory to save the network file
        version: Version tag for the data
        include_edge_attributes: Whether to calculate edge attributes
        min_downloads: Minimum downloads to include a model
        max_nodes: Maximum number of nodes (top N by downloads)
        load_from_hf: If True, load directly from HF dataset (includes parent relationships)
        sample_size: If load_from_hf=True, sample this many models (None = all models)

    Returns:
        True on success; False if loading, building, or saving failed.
    """
    start_time = time.time()
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info("=" * 60)
    logger.info("PRE-COMPUTING FULL DERIVATIVE NETWORK")
    logger.info("=" * 60)

    # Step 1: Load model data from either source.
    logger.info("Step 1/3: Loading model data...")
    if load_from_hf:
        df = _load_models_from_hf(sample_size)
    else:
        df = _load_models_from_precomputed(output_dir, version)
        if df is None:
            return False

    # Step 2: Apply optional download/node-count filters.
    df = _apply_filters(df, min_downloads, max_nodes)

    # Step 3: Build the network graph (the expensive step).
    logger.info("Step 2/3: Building network graph (this may take 10-30 minutes)...")
    try:
        network_builder = ModelNetworkBuilder(df)
        graph = network_builder.build_full_derivative_network(
            include_edge_attributes=include_edge_attributes,
            filter_edge_types=None  # Include all edge types
        )
        logger.info(f"Graph built: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
    except Exception as e:
        logger.error(f"Failed to build network graph: {e}", exc_info=True)
        return False

    # Step 4: Persist the graph, then a JSON metadata sidecar.
    logger.info("Step 3/3: Saving network graph to disk...")
    network_file = output_path / "full_derivative_network.pkl"
    file_size_mb = _save_graph(graph, network_file)
    if file_size_mb is None:
        return False
    _write_network_metadata(
        output_path / "network_metadata.json",
        graph,
        version,
        include_edge_attributes,
        min_downloads,
        max_nodes,
        file_size_mb,
    )

    total_time = time.time() - start_time
    logger.info("=" * 60)
    logger.info(f"PRE-COMPUTATION COMPLETE in {total_time:.2f} seconds")
    logger.info(f"Network graph: {graph.number_of_nodes():,} nodes, {graph.number_of_edges():,} edges")
    logger.info(f"Saved to: {network_file}")
    logger.info("=" * 60)
    return True
if __name__ == "__main__":
    # NOTE: `time` is already imported at module scope and is unused here, so the
    # redundant re-import that used to live in this block was removed.
    parser = argparse.ArgumentParser(description="Pre-compute full derivative network graph")
    parser.add_argument("--output-dir", type=str, default="precomputed_data",
                        help="Output directory for pre-computed files")
    parser.add_argument("--version", type=str, default="v1",
                        help="Version tag for the data")
    parser.add_argument("--include-edge-attributes", action="store_true",
                        help="Include edge attributes (slower but more detailed)")
    parser.add_argument("--min-downloads", type=int, default=0,
                        help="Minimum downloads to include a model")
    parser.add_argument("--max-nodes", type=int, default=None,
                        help="Maximum number of nodes (top N by downloads)")
    parser.add_argument("--load-from-hf", action="store_true",
                        help="Load directly from HF dataset instead of pre-computed files (includes parent relationships)")
    parser.add_argument("--sample-size", type=int, default=None,
                        help="If --load-from-hf, sample this many models (default: all models, use 0 for all)")
    args = parser.parse_args()
    success = precompute_network(
        output_dir=args.output_dir,
        version=args.version,
        include_edge_attributes=args.include_edge_attributes,
        min_downloads=args.min_downloads,
        max_nodes=args.max_nodes,
        load_from_hf=args.load_from_hf,
        # An explicit 0 means "all models", which precompute_network expects as None.
        sample_size=None if args.sample_size == 0 else args.sample_size
    )
    # Conventional exit codes: 0 on success, 1 on failure.
    sys.exit(0 if success else 1)