| """ | |
| Real Model Loader for Hugging Face Models | |
| Manages model loading, caching, and inference | |
| Works with public HuggingFace models without requiring authentication | |
| """ | |
| import os | |
| import logging | |
| from typing import Dict, Any, Optional, List | |
| from functools import lru_cache | |
| # Required ML libraries - these MUST be installed | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModel, | |
| AutoModelForSequenceClassification, | |
| AutoModelForTokenClassification, | |
| pipeline | |
| ) | |

logger = logging.getLogger(__name__)

# Get HF token from environment (optional - most models are public)
HF_TOKEN = os.getenv("HF_TOKEN", None)
if HF_TOKEN:
    logger.info("HF_TOKEN found - will use for gated models if needed")
else:
    logger.info("HF_TOKEN not found - using public models only (this is normal)")


class ModelLoader:
    """
    Manages loading and caching of Hugging Face models.
    Implements lazy loading and GPU optimization.
    """

    def __init__(self):
        """Initialize the model loader with GPU support if available"""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded_models = {}
        self.model_configs = self._get_model_configs()

        # Log system information
        logger.info(f"Model Loader initialized on device: {self.device}")
        logger.info(f"PyTorch version: {torch.__version__}")
        logger.info(f"CUDA available: {torch.cuda.is_available()}")

        # Verify model configs are properly loaded
        logger.info(f"Model configurations loaded: {len(self.model_configs)} models")
        for key in self.model_configs:
            logger.info(f"  - {key}: {self.model_configs[key]['model_id']}")

    def _get_model_configs(self) -> Dict[str, Dict[str, Any]]:
        """
        Configuration for real Hugging Face models.
        Maps tasks to actual model names on the Hugging Face Hub.
        """
        return {
            # Document Classification
            "document_classifier": {
                "model_id": "emilyalsentzer/Bio_ClinicalBERT",
                "task": "text-classification",
                "description": "Clinical document type classification"
            },
            # Clinical NER
            "clinical_ner": {
                "model_id": "d4data/biomedical-ner-all",
                "task": "ner",
                "description": "Biomedical named entity recognition"
            },
            # Clinical Text Generation
            "clinical_generation": {
                "model_id": "microsoft/BioGPT-Large",
                "task": "text-generation",
                "description": "Clinical text generation and summarization"
            },
            # Medical Question Answering
            "medical_qa": {
                "model_id": "deepset/roberta-base-squad2",
                "task": "question-answering",
                "description": "Medical question answering"
            },
            # General Medical Analysis
            "general_medical": {
                "model_id": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
                "task": "feature-extraction",
                "description": "General medical text understanding"
            },
            # Drug-Drug Interaction
            "drug_interaction": {
                "model_id": "allenai/scibert_scivocab_uncased",
                "task": "feature-extraction",
                "description": "Drug interaction detection"
            },
            # Radiology Report Generation (fallback to general medical)
            "radiology_generation": {
                "model_id": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
                "task": "feature-extraction",
                "description": "Radiology report analysis"
            },
            # Clinical Summarization
            "clinical_summarization": {
                "model_id": "google/bigbird-pegasus-large-pubmed",
                "task": "summarization",
                "description": "Clinical document summarization"
            }
        }
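
    # Adding another task is just another entry of the same shape. A
    # hypothetical example (the model ID is a public Hub model, not one
    # this loader ships with):
    #
    #     "sentiment": {
    #         "model_id": "distilbert-base-uncased-finetuned-sst-2-english",
    #         "task": "text-classification",
    #         "description": "General sentiment analysis"
    #     }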

    def load_model(self, model_key: str) -> Optional[Any]:
        """
        Load a model by key, with caching.

        Most Hugging Face models are public and don't require authentication;
        HF_TOKEN is only needed for private/gated models.
        """
        try:
            # Check if already loaded
            if model_key in self.loaded_models:
                logger.info(f"Using cached model: {model_key}")
                return self.loaded_models[model_key]

            # Get model configuration
            if model_key not in self.model_configs:
                logger.warning(f"Unknown model key: {model_key}, using fallback")
                model_key = "general_medical"

            config = self.model_configs[model_key]
            model_id = config["model_id"]
            task = config["task"]
            logger.info(f"Loading model: {model_id} for task: {task}")

            # Try loading with a pipeline first (works for most public models);
            # pass the token only if available (most models don't need it)
            try:
                pipeline_kwargs = {
                    "task": task,
                    "model": model_id,
                    "device": 0 if self.device == "cuda" else -1,
                    "trust_remote_code": True
                }
                # Only add the token if it exists (avoid passing None/empty string)
                if HF_TOKEN:
                    pipeline_kwargs["token"] = HF_TOKEN

                model_pipeline = pipeline(**pipeline_kwargs)
                self.loaded_models[model_key] = model_pipeline
                logger.info(f"Successfully loaded model: {model_id}")
                return model_pipeline
            except Exception as e:
                error_msg = str(e).lower()
                # Check whether this is an authentication error
                if "401" in error_msg or "unauthorized" in error_msg or "authentication" in error_msg:
                    if not HF_TOKEN:
                        logger.error(f"Model {model_id} requires authentication but HF_TOKEN not available")
                        logger.error("This model is gated/private. Using public alternative or fallback.")
                    else:
                        logger.error(f"Model {model_id} authentication failed even with HF_TOKEN")
                else:
                    logger.error(f"Failed to load model {model_id}: {str(e)}")

            # Fall back to loading with AutoModel directly
            try:
                logger.info(f"Trying alternative loading method for {model_id}...")
                load_kwargs = {"trust_remote_code": True}
                if HF_TOKEN:
                    load_kwargs["token"] = HF_TOKEN

                tokenizer = AutoTokenizer.from_pretrained(model_id, **load_kwargs)
                model = AutoModel.from_pretrained(model_id, **load_kwargs).to(self.device)

                self.loaded_models[model_key] = {
                    "tokenizer": tokenizer,
                    "model": model,
                    "type": "custom"
                }
                logger.info(f"Successfully loaded {model_id} with alternative method")
                return self.loaded_models[model_key]
            except Exception as inner_e:
                logger.error(f"Alternative loading also failed for {model_id}: {str(inner_e)}")
                logger.info(f"Model {model_key} unavailable - will use fallback analysis")
                return None
        except Exception as e:
            logger.error(f"Model loading failed for {model_key}: {str(e)}")
            return None

    def run_inference(
        self,
        model_key: str,
        input_text: str,
        task_params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Run inference on a loaded model.
        """
        try:
            model = self.load_model(model_key)
            if model is None:
                return {
                    "error": "Model not available",
                    "model_key": model_key
                }

            # Copy so the caller's dict is not mutated by the pop below
            task_params = dict(task_params or {})

            # Handle pipeline models. Note: question-answering pipelines expect
            # a {"question": ..., "context": ...} input, so a bare string for
            # "medical_qa" will surface as an error via the except block below.
            if callable(model) and not isinstance(model, dict):
                # Pop max_length so it isn't also passed via **task_params
                # (which would raise a duplicate-keyword TypeError), and
                # truncate the raw input to avoid token-limit issues
                max_length = task_params.pop("max_length", 512)
                result = model(
                    input_text[:4000],  # Limit input length
                    max_length=max_length,
                    truncation=True,
                    **task_params
                )
                return {
                    "success": True,
                    "result": result,
                    "model_key": model_key
                }
            # Handle custom loaded models
            elif isinstance(model, dict) and model.get("type") == "custom":
                tokenizer = model["tokenizer"]
                model_obj = model["model"]
                inputs = tokenizer(
                    input_text[:512],
                    return_tensors="pt",
                    truncation=True,
                    max_length=512
                ).to(self.device)
                with torch.no_grad():
                    outputs = model_obj(**inputs)
                return {
                    "success": True,
                    "result": {
                        # Mean-pool token embeddings; pooler_output may be
                        # absent or None depending on the architecture
                        "embeddings": outputs.last_hidden_state.mean(dim=1).cpu().tolist(),
                        "pooled": (
                            outputs.pooler_output.cpu().tolist()
                            if getattr(outputs, "pooler_output", None) is not None
                            else None
                        )
                    },
                    "model_key": model_key
                }
            else:
                return {
                    "error": "Unknown model type",
                    "model_key": model_key
                }
        except Exception as e:
            logger.error(f"Inference failed for {model_key}: {str(e)}")
            return {
                "error": str(e),
                "model_key": model_key
            }
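
    # For reference, a pipeline result for "clinical_ner" is roughly a list of
    # entity dicts, e.g. [{"entity": ..., "word": ..., "score": ...}, ...];
    # exact keys vary with the transformers version and aggregation settings.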

    def clear_cache(self, model_key: Optional[str] = None):
        """Clear model cache to free memory"""
        if model_key:
            if model_key in self.loaded_models:
                del self.loaded_models[model_key]
                logger.info(f"Cleared cache for model: {model_key}")
        else:
            self.loaded_models.clear()
            logger.info("Cleared all model caches")

        # Force garbage collection, then clear the GPU cache if available
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def test_model_loading(self) -> Dict[str, Any]:
        """Test loading all configured models to verify AI functionality"""
        results = {
            "total_models": len(self.model_configs),
            "models_loaded": 0,
            "models_failed": 0,
            "errors": [],
            "device": self.device,
            "pytorch_version": torch.__version__
        }

        for model_key, config in self.model_configs.items():
            try:
                logger.info(f"Testing model: {model_key} ({config['model_id']})")
                # Try to load the model and run a small inference. Note that
                # generation-only params such as max_new_tokens are rejected by
                # non-generation pipelines; those failures are recorded below.
                test_input = "Test ECG analysis request"
                result = self.run_inference(model_key, test_input, {"max_new_tokens": 50})

                if result.get("success"):
                    results["models_loaded"] += 1
                    logger.info(f"✅ {model_key}: Loaded successfully")
                else:
                    results["models_failed"] += 1
                    error_msg = result.get("error", "Unknown error")
                    results["errors"].append(f"{model_key}: {error_msg}")
                    logger.warning(f"⚠️ {model_key}: {error_msg}")
            except Exception as e:
                results["models_failed"] += 1
                error_msg = f"Exception during loading: {str(e)}"
                results["errors"].append(f"{model_key}: {error_msg}")
                logger.error(f"❌ {model_key}: {error_msg}")

        logger.info(f"Model loading test complete: {results['models_loaded']}/{results['total_models']} successful")
        return results


# Global model loader instance
_model_loader = None


def get_model_loader() -> ModelLoader:
    """Get singleton model loader instance"""
    global _model_loader
    if _model_loader is None:
        _model_loader = ModelLoader()
    return _model_loader
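

if __name__ == "__main__":
    # Minimal usage sketch. Assumes network access to the Hugging Face Hub and
    # enough RAM/VRAM for the chosen model; the first call downloads weights,
    # so it can take a while.
    logging.basicConfig(level=logging.INFO)

    loader = get_model_loader()

    # Run biomedical NER over a short clinical snippet via the shared singleton
    output = loader.run_inference(
        "clinical_ner",
        "Patient was prescribed 5 mg amlodipine for hypertension."
    )
    if output.get("success"):
        print(output["result"])
    else:
        print(f"Inference failed: {output.get('error')}")

    # Release model memory when done
    loader.clear_cache()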