"""Model manager for keypoint–argument matching model""" import os import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification from huggingface_hub import hf_hub_download import logging logger = logging.getLogger(__name__) class KpaModelManager: """Manages loading and inference for keypoint matching model""" def __init__(self): self.model = None self.tokenizer = None self.device = None self.model_loaded = False self.max_length = 256 self.model_id = None def load_model(self, model_id: str, api_key: str = None): """Load model with weights from Hugging Face repository""" if self.model_loaded: logger.info("KPA model already loaded") return try: logger.info(f"Loading KPA model from Hugging Face: {model_id}") # Determine device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {self.device}") # Store model ID self.model_id = model_id # Prepare token for authentication if API key is provided token = api_key if api_key else None # Load base tokenizer (distilbert-base-uncased) base_model_name = "distilbert-base-uncased" logger.info(f"Loading tokenizer from {base_model_name}...") self.tokenizer = AutoTokenizer.from_pretrained(base_model_name) # Load base model architecture logger.info(f"Loading base model architecture from {base_model_name}...") self.model = AutoModelForSequenceClassification.from_pretrained( base_model_name, num_labels=2 ) # Download and load fine-tuned weights from Hugging Face logger.info(f"Downloading fine-tuned weights from {model_id}...") weights_path = hf_hub_download( repo_id=model_id, filename="modele_appariement_rapide.pth", token=token ) logger.info(f"Loading fine-tuned weights from {weights_path}...") checkpoint = torch.load(weights_path, map_location=self.device) # Load state dict if "model_state_dict" in checkpoint: self.model.load_state_dict(checkpoint["model_state_dict"]) else: self.model.load_state_dict(checkpoint) self.model.to(self.device) self.model.eval() self.model_loaded = True logger.info("✓ KPA model loaded successfully from Hugging Face!") except Exception as e: logger.error(f"Error loading KPA model: {str(e)}") raise RuntimeError(f"Failed to load KPA model: {str(e)}") def predict(self, argument: str, key_point: str) -> dict: """Run a prediction for (argument, key_point)""" if not self.model_loaded: raise RuntimeError("KPA model not loaded") try: # Tokenize input encoding = self.tokenizer( argument, key_point, truncation=True, padding="max_length", max_length=self.max_length, return_tensors="pt" ).to(self.device) # Forward pass with torch.no_grad(): outputs = self.model(**encoding) logits = outputs.logits probabilities = torch.softmax(logits, dim=-1) predicted_class = torch.argmax(probabilities, dim=-1).item() confidence = probabilities[0][predicted_class].item() return { "prediction": predicted_class, "confidence": confidence, "label": "apparie" if predicted_class == 1 else "non_apparie", "probabilities": { "non_apparie": probabilities[0][0].item(), "apparie": probabilities[0][1].item(), }, } except Exception as e: logger.error(f"Error during prediction: {str(e)}") raise RuntimeError(f"KPA prediction failed: {str(e)}") def get_model_info(self): return { "model_name": self.model_id, "device": str(self.device), "max_length": self.max_length, "num_labels": 2, "loaded": self.model_loaded } kpa_model_manager = KpaModelManager()