""" Recommend Hugging Face models based on the app plan, size preference, and GPU needs. """ from typing import Optional # Curated model recommendations by task and size MODEL_CATALOG = { "text-generation": { "small": [ {"id": "HuggingFaceTB/SmolLM2-360M-Instruct", "desc": "Compact instruct model, very fast", "size": "360M"}, ], "medium": [ {"id": "Qwen/Qwen2.5-7B-Instruct", "desc": "Strong general-purpose 7B instruct model", "size": "7B"}, {"id": "mistralai/Mistral-7B-Instruct-v0.3", "desc": "Fast and capable instruct model", "size": "7B"}, ], "large": [ {"id": "meta-llama/Llama-3.1-70B-Instruct", "desc": "Top-tier large language model", "size": "70B"}, {"id": "Qwen/Qwen2.5-72B-Instruct", "desc": "Excellent multilingual reasoning", "size": "72B"}, ], }, "text-classification": { "small": [ {"id": "distilbert-base-uncased-finetuned-sst-2-english", "desc": "Fast sentiment classifier", "size": "66M"}, ], "medium": [ {"id": "cardiffnlp/twitter-roberta-base-sentiment-latest", "desc": "Social media sentiment analysis", "size": "125M"}, {"id": "j-hartmann/emotion-english-distilroberta-base", "desc": "Multi-emotion classifier", "size": "82M"}, ], "large": [ {"id": "SamLowe/roberta-base-go_emotions", "desc": "28-class emotion detection", "size": "125M"}, ], }, "summarization": { "small": [ {"id": "sshleifer/distilbart-cnn-12-6", "desc": "Compact summarization model", "size": "306M"}, ], "medium": [ {"id": "facebook/bart-large-cnn", "desc": "Strong CNN/DailyMail summarizer", "size": "406M"}, {"id": "google/pegasus-xsum", "desc": "Abstractive summarization", "size": "568M"}, ], "large": [ {"id": "google/pegasus-large", "desc": "High-quality abstractive summaries", "size": "568M"}, ], }, "translation": { "small": [ {"id": "Helsinki-NLP/opus-mt-en-fr", "desc": "English to French translation", "size": "298M"}, ], "medium": [ {"id": "facebook/mbart-large-50-many-to-many-mmt", "desc": "50-language translation", "size": "611M"}, ], "large": [ {"id": "facebook/nllb-200-3.3B", "desc": "200-language translation model", "size": "3.3B"}, ], }, "image-classification": { "small": [ {"id": "google/mobilenet_v2_1.0_224", "desc": "Mobile-optimized classifier", "size": "3.4M"}, ], "medium": [ {"id": "microsoft/resnet-50", "desc": "Classic ResNet-50 ImageNet classifier", "size": "25.6M"}, {"id": "google/vit-base-patch16-224", "desc": "Vision Transformer classifier", "size": "86M"}, ], "large": [ {"id": "google/vit-large-patch16-224", "desc": "Large Vision Transformer", "size": "304M"}, ], }, "object-detection": { "small": [ {"id": "hustvl/yolos-tiny", "desc": "Tiny YOLO-style detector", "size": "6.5M"}, ], "medium": [ {"id": "facebook/detr-resnet-50", "desc": "DETR object detector", "size": "41M"}, ], "large": [ {"id": "facebook/detr-resnet-101", "desc": "Large DETR detector", "size": "60M"}, ], }, "text-to-image": { "small": [ {"id": "segmind/SSD-1B", "desc": "Compact SD distilled model", "size": "1.3B"}, ], "medium": [ {"id": "stabilityai/stable-diffusion-xl-base-1.0", "desc": "SDXL base model", "size": "3.5B"}, ], "large": [ {"id": "black-forest-labs/FLUX.1-dev", "desc": "State-of-the-art image gen", "size": "12B"}, ], }, "automatic-speech-recognition": { "small": [ {"id": "openai/whisper-tiny", "desc": "Tiny Whisper ASR", "size": "39M"}, ], "medium": [ {"id": "openai/whisper-base", "desc": "Whisper base ASR model", "size": "74M"}, {"id": "openai/whisper-medium", "desc": "Whisper medium ASR model", "size": "769M"}, ], "large": [ {"id": "openai/whisper-large-v3", "desc": "Best Whisper ASR model", "size": "1.5B"}, ], }, "question-answering": { "small": [ {"id": "distilbert-base-cased-distilled-squad", "desc": "Fast QA model", "size": "66M"}, ], "medium": [ {"id": "deepset/roberta-base-squad2", "desc": "RoBERTa QA on SQuAD2", "size": "125M"}, ], "large": [ {"id": "deepset/deberta-v3-large-squad2", "desc": "DeBERTa large QA", "size": "304M"}, ], }, "token-classification": { "small": [ {"id": "dslim/bert-base-NER", "desc": "BERT NER model", "size": "110M"}, ], "medium": [ {"id": "Jean-Baptiste/roberta-large-ner-english", "desc": "Large NER model", "size": "355M"}, ], "large": [ {"id": "Jean-Baptiste/roberta-large-ner-english", "desc": "Large NER model", "size": "355M"}, ], }, } # GPU recommendation thresholds GPU_THRESHOLDS = { "small": False, "medium": False, # most medium models run on CPU "large": True, } class ModelRecommender: """Recommend HF models based on app plan and user preferences.""" def recommend( self, plan: dict, model_size: str = "medium", gpu_needed: bool = False, ) -> list: """ Return a list of recommended model dicts. Each dict has: id, desc, size, gpu_recommended """ task = plan.get("model_task") if not task: return [] # Normalize size if model_size not in ("small", "medium", "large"): model_size = "medium" task_models = MODEL_CATALOG.get(task, {}) candidates = task_models.get(model_size, []) # Fallback: try adjacent sizes if not candidates: for fallback in ("medium", "small", "large"): candidates = task_models.get(fallback, []) if candidates: break # If still nothing, provide a generic suggestion if not candidates: candidates = [ { "id": f"models?pipeline_tag={task}", "desc": f"Search HF Hub for {task} models", "size": "varies", } ] # Annotate with GPU recommendation results = [] for model in candidates: m = dict(model) m["gpu_recommended"] = gpu_needed or GPU_THRESHOLDS.get(model_size, False) results.append(m) return results def get_primary_model(self, plan: dict, model_size: str = "medium") -> Optional[str]: """Get the single best model ID for the plan.""" models = self.recommend(plan, model_size) if models: return models[0]["id"] return None