"""Gradio app entry point for HuggingFace Spaces. Run locally: cd web && python app.py Deploy to HF Spaces: Push the contents of ``web/`` (plus ``assets/model_pool.npz`` and the checkpoint at ``checkpoint/...``) to a new Space with sdk=gradio. """ from __future__ import annotations import os import traceback import gradio as gr import pandas as pd from recommend import default_recommender # Load once at module import time so the model is warm before the first request. print("Loading recommender ...") RECOMMENDER = default_recommender() print(f"Loaded recommender: {len(RECOMMENDER.model_names)} candidate models, " f"{len(RECOMMENDER.task2id)} tasks, {len(RECOMMENDER.metric2id)} metrics.") # Sort the dropdown choices for a sane UX. TASK_CHOICES = sorted(RECOMMENDER.task2id.keys(), key=lambda x: x.lower()) # Metric vocab is huge (3k+) and noisy — restrict to the most common bare metric names. COMMON_METRICS = [ "accuracy", "f1", "exact_match", "rouge_l", "bleu", "mean_iou", "mean_average_precision", "top_1_accuracy", "top_5_accuracy", "perplexity", "wer", "auc", "spearman", "pearson", "mse", "rmse", "mc2", "accuracy_norm", "strict_accuracy", ] # Keep only those actually present in the metric vocab (with loose alias matching). 
METRIC_CHOICES = sorted(
    {m for m in COMMON_METRICS
     if RECOMMENDER.resolve_metric(m) != RECOMMENDER.model.unknown_metric_id}
)
if not METRIC_CHOICES:
    # Defensive fallback: if alias resolution matched nothing (e.g. a stale
    # metric vocab), show the raw common-metric list rather than an empty
    # dropdown. (Dropped the original's redundant `"accuracy" in
    # COMMON_METRICS` guard — always true for this literal list.)
    METRIC_CHOICES = COMMON_METRICS

EXAMPLE_DESCRIPTIONS = [
    "MMLU is a multiple-choice benchmark covering 57 academic subjects, evaluating broad knowledge and reasoning ability across humanities, STEM, and social sciences.",
    "GSM8K is a dataset of 8.5K high-quality grade-school math word problems requiring multi-step arithmetic reasoning to arrive at a single numerical answer.",
    "ImageNet-1K contains roughly 1.28M natural images labeled with one of 1000 fine-grained object categories, widely used for image classification benchmarking.",
    "CoNLL 2003 is an English named-entity recognition corpus annotating persons, organizations, locations, and miscellaneous entities in news wire text.",
]


def _format_size(size_b: float) -> str:
    """Pretty-print a parameter count given in billions of parameters.

    Returns e.g. '7.0B' (>= 1B), '350M' (>= 1M), a K-suffixed count below
    that, or '—' when the value is missing (None/NaN) or non-positive.
    """
    # NaN is the only float that compares unequal to itself.
    if size_b is None or size_b != size_b or size_b <= 0:
        return "—"
    if size_b >= 1.0:
        return f"{size_b:.1f}B"
    if size_b >= 0.001:
        return f"{size_b * 1000:.0f}M"
    return f"{size_b * 1_000_000:.0f}K"


def recommend_ui(dataset_description: str, task: str, metric: str, top_k: int,
                 min_size: float, max_size: float, official_only: bool,
                 hf_only: bool, api_key: str):
    """Gradio callback: validate inputs, run the recommender, shape results.

    Returns a ``(DataFrame, status_markdown)`` pair. On any user-facing
    problem the frame is empty and the status message explains what to fix.
    """
    if not (dataset_description or "").strip():
        return pd.DataFrame(columns=["rank", "model", "score", "size", "popularity", "link"]), \
            "Please enter a dataset description."

    api_key = (api_key or "").strip()
    if not api_key and not os.environ.get("OPENAI_API_KEY"):
        return pd.DataFrame(), (
            "⚠️ Please paste your OpenAI API key in the field above. "
            "We use it once per request to embed your dataset description; "
            "the key is **not stored or logged** by this app."
        )

    # 0 / blank means "no limit" on that side.
    min_b = float(min_size) if min_size and float(min_size) > 0 else None
    max_b = float(max_size) if max_size and float(max_size) > 0 else None
    if min_b is not None and max_b is not None and min_b > max_b:
        return pd.DataFrame(), "⚠️ Min size must be ≤ max size."

    try:
        recs = RECOMMENDER.recommend(
            dataset_description=dataset_description,
            task=task,
            metric=metric,
            top_k=int(top_k),
            popularity_weight=0.0,
            hf_only=bool(hf_only),
            min_size_b=min_b,
            max_size_b=max_b,
            official_only=bool(official_only),
            api_key=api_key or None,
        )
    except ValueError as e:
        # Expected validation errors from the recommender: show the message as-is.
        return pd.DataFrame(), f"⚠️ {e}"
    except Exception:
        # Anything else: surface the traceback so Space users can report it.
        return pd.DataFrame(), f"⚠️ Internal error:\n```\n{traceback.format_exc()}\n```"

    rows = []
    for r in recs:
        link = f"[link]({r.hf_url})" if r.hf_url else "—"
        rows.append({
            "rank": r.rank,
            "model": r.model_name,
            "score": round(r.score, 4),
            "size": _format_size(r.size_b),
            "popularity": r.popularity,
            "link": link,
        })
    df = pd.DataFrame(rows, columns=["rank", "model", "score", "size", "popularity", "link"])
    return df, f"Returned top-{len(rows)} of {len(RECOMMENDER.model_names)} candidates."


with gr.Blocks(title="ModelLens · Finding the Best Model for Your Task",
               theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# ModelLens: Finding the Best Model for Your Task from Myriads of Models

Describe your dataset, pick a task type and a metric, and ModelLens returns
the top candidates from a pool of **47k+** HuggingFace models. Backed by the
ablation_no_id MLPMetric checkpoint trained on `unified_augmented`.

> **BYO OpenAI key.** This Space embeds your dataset description with
> `text-embedding-3-small`.
"""
    )
    with gr.Row():
        with gr.Column(scale=2):
            desc = gr.Textbox(
                label="Dataset description",
                placeholder="Describe your dataset in 2-3 sentences. The more specific, the better.",
                lines=5,
            )
            with gr.Row():
                task = gr.Dropdown(
                    choices=TASK_CHOICES,
                    label="Task type",
                    value="Question Answering" if "Question Answering" in TASK_CHOICES else TASK_CHOICES[0],
                    filterable=True,
                )
                metric = gr.Dropdown(
                    choices=METRIC_CHOICES,
                    label="Metric (optional)",
                    value="accuracy" if "accuracy" in METRIC_CHOICES
                    else (METRIC_CHOICES[0] if METRIC_CHOICES else None),
                    filterable=True,
                    allow_custom_value=True,
                )
            top_k = gr.Slider(5, 100, value=20, step=5, label="Top-k")
            api_key = gr.Textbox(
                label="OpenAI API key (sk-...)",
                placeholder="Paste your key — used once per request, never stored or logged.",
                type="password",
                lines=1,
            )
            with gr.Row():
                min_size = gr.Number(
                    value=0,
                    label="Min size (B params, 0 = no min)",
                    minimum=0,
                    precision=2,
                )
                max_size = gr.Number(
                    value=0,
                    label="Max size (B params, 0 = no max)",
                    minimum=0,
                    precision=2,
                )
            official_only = gr.Checkbox(
                value=False,
                label="Only show official pretrained models (DeepSeek, Qwen, Llama, gpt-oss, Mistral, Gemma, Phi, ...)",
            )
            hf_only = gr.Checkbox(
                value=True,
                label="Only show models hosted on HuggingFace (drops paper baselines like 'inceptionv4')",
            )
            run_btn = gr.Button("Search", variant="primary")
            gr.Examples(
                examples=[[d] for d in EXAMPLE_DESCRIPTIONS],
                inputs=[desc],
                outputs=[],
                label="Example dataset descriptions (click to fill, then press Search)",
                run_on_click=False,
            )
        with gr.Column(scale=3):
            status = gr.Markdown("")
            table = gr.Dataframe(
                headers=["rank", "model", "score", "size", "popularity", "link"],
                interactive=False,
                wrap=True,
                datatype=["number", "str", "number", "str", "number", "markdown"],
            )

    run_btn.click(
        recommend_ui,
        inputs=[desc, task, metric, top_k, min_size, max_size, official_only, hf_only, api_key],
        outputs=[table, status],
    )


if __name__ == "__main__":
    demo.queue(max_size=16).launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7860)),
        share=False,
    )