Yassine Mhirsi
Enhance KpaModelManager to load fine-tuned weights from Hugging Face and update requirements to include huggingface_hub
d289997
| """Model manager for keypoint–argument matching model""" | |
| import os | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from huggingface_hub import hf_hub_download | |
| import logging | |
logger = logging.getLogger(__name__)


class KpaModelManager:
    """Manage loading and inference for the key point-argument matching model.

    The classifier is built from the ``distilbert-base-uncased`` architecture
    with two output labels; its fine-tuned weights are downloaded from a
    Hugging Face repository when ``load_model`` is called.
    """

    # Checkpoint that provides both the tokenizer and the model architecture.
    BASE_MODEL_NAME = "distilbert-base-uncased"
    # File holding the fine-tuned weights inside the Hugging Face repository.
    WEIGHTS_FILENAME = "modele_appariement_rapide.pth"
    # Label names indexed by predicted class id.
    LABELS = ("non_apparie", "apparie")

    def __init__(self):
        """Create an empty manager; nothing is downloaded until ``load_model``."""
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.max_length = 256
        self.model_id = None

    def load_model(self, model_id: str, api_key: str = None):
        """Load the base model and apply fine-tuned weights from Hugging Face.

        Returns immediately (after logging) when a model is already loaded,
        regardless of the ``model_id`` passed.

        Args:
            model_id: Hugging Face repository id that hosts the weights file.
            api_key: Optional access token forwarded to ``hf_hub_download``;
                an empty value is forwarded as ``None``.

        Raises:
            RuntimeError: If any loading step fails. The underlying exception
                is attached as ``__cause__``.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            logger.info(f"Loading KPA model from Hugging Face: {model_id}")

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")

            self.model_id = model_id
            token = api_key or None

            logger.info(f"Loading tokenizer from {self.BASE_MODEL_NAME}...")
            tokenizer = AutoTokenizer.from_pretrained(self.BASE_MODEL_NAME)

            logger.info(f"Loading base model architecture from {self.BASE_MODEL_NAME}...")
            model = AutoModelForSequenceClassification.from_pretrained(
                self.BASE_MODEL_NAME,
                num_labels=len(self.LABELS),
            )

            logger.info(f"Downloading fine-tuned weights from {model_id}...")
            weights_path = hf_hub_download(
                repo_id=model_id,
                filename=self.WEIGHTS_FILENAME,
                token=token,
            )

            logger.info(f"Loading fine-tuned weights from {weights_path}...")
            # NOTE(review): torch.load unpickles the downloaded file, which can
            # execute arbitrary code on PyTorch versions where weights_only
            # defaults to False. Only point model_id at trusted repositories,
            # or pass weights_only=True once the checkpoint is confirmed to
            # contain nothing but tensors and plain containers.
            checkpoint = torch.load(weights_path, map_location=self.device)

            # The checkpoint is either a bare state dict or a training
            # checkpoint that wraps it under "model_state_dict".
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint
            model.load_state_dict(state_dict)

            model.to(self.device)
            model.eval()

            # Publish only after every step succeeded, so a failed load never
            # leaves a base model without fine-tuned weights on the manager.
            self.tokenizer = tokenizer
            self.model = model
            self.model_loaded = True

            logger.info("✓ KPA model loaded successfully from Hugging Face!")

        except Exception as e:
            logger.exception(f"Error loading KPA model: {str(e)}")
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

    def predict(self, argument: str, key_point: str) -> dict:
        """Classify whether ``argument`` matches ``key_point``.

        Args:
            argument: Argument text, encoded as the first sequence of the pair.
            key_point: Key point text, encoded as the second sequence.

        Returns:
            A dict with ``prediction`` (class id), ``confidence`` (softmax
            probability of the predicted class), ``label`` (``"apparie"`` for
            class 1, otherwise ``"non_apparie"``) and ``probabilities`` (the
            softmax probability of each label).

        Raises:
            RuntimeError: If the model is not loaded, or if tokenization or
                inference fails (the underlying exception is attached as
                ``__cause__``).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # Inference only: no gradients are needed.
            with torch.no_grad():
                outputs = self.model(**encoding)

            # Single input pair, so row 0 holds the only set of scores.
            probabilities = torch.softmax(outputs.logits, dim=-1)[0]
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            # Read all scores back in one transfer instead of one per value.
            scores = probabilities.tolist()

            return {
                "prediction": predicted_class,
                "confidence": scores[predicted_class],
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": scores[0],
                    "apparie": scores[1],
                },
            }

        except Exception as e:
            logger.exception(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self):
        """Return a summary of the manager's configuration and load state.

        ``device`` is the string form of ``self.device``, so it is ``"None"``
        before the first ``load_model`` call.
        """
        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": len(self.LABELS),
            "loaded": self.model_loaded,
        }


# Instance created at import time; no model is loaded until load_model() runs.
kpa_model_manager = KpaModelManager()