FastAPI-Backend-Models / services /label_model_manage.py
Yassine Mhirsi
Add KPA model integration and update configuration for label predictions
f28285b
raw
history blame
3.82 kB
"""Model manager for keypoint–argument matching model"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import logging
# Module-level logger named after this module, used by KpaModelManager below.
logger = logging.getLogger(__name__)
class KpaModelManager:
    """Manage loading and inference for the key point-argument matching model.

    The manager starts empty; ``load_model`` must succeed before ``predict``
    can be used. State on the instance is only updated once loading has fully
    succeeded, so a failed load leaves the manager exactly as it was.
    """

    def __init__(self):
        # All of these stay at their defaults until load_model() succeeds.
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        # Upper bound (in tokens) used for truncation and padding in predict().
        self.max_length = 256
        self.model_id = None

    def load_model(self, model_id: str, api_key: str = None):
        """Load the tokenizer and classification model from Hugging Face.

        Does nothing (apart from logging) when a model is already loaded.

        Args:
            model_id: Hugging Face model identifier passed to
                ``from_pretrained``.
            api_key: Optional access token. ``None`` or an empty string means
                no token is sent.

        Raises:
            RuntimeError: If anything goes wrong while loading. The original
                exception is attached as ``__cause__`` and the manager's
                state is left untouched.
        """
        if self.model_loaded:
            logger.info("KPA model already loaded")
            return

        try:
            logger.info(f"Loading KPA model from Hugging Face: {model_id}")

            # Prefer the GPU when one is available.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {device}")

            # Empty strings are treated like "no token".
            token = api_key or None

            # NOTE(security): trust_remote_code=True allows the model
            # repository to run its own Python code while loading. Only use
            # model IDs from sources you trust.
            logger.info("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )

            logger.info("Loading model...")
            model = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                token=token,
                trust_remote_code=True,
            )
            model.to(device)
            model.eval()
        except Exception as e:
            logger.error(f"Error loading KPA model: {str(e)}")
            raise RuntimeError(f"Failed to load KPA model: {str(e)}") from e

        # Commit everything at once, only after every step above succeeded,
        # so a failed attempt never leaves a half-initialised manager behind.
        self.device = device
        self.model_id = model_id
        self.tokenizer = tokenizer
        self.model = model
        self.model_loaded = True
        logger.info("✓ KPA model loaded successfully from Hugging Face!")

    def predict(self, argument: str, key_point: str) -> dict:
        """Classify whether ``argument`` matches ``key_point``.

        Args:
            argument: The argument text (first segment of the input pair).
            key_point: The key point text (second segment of the input pair).

        Returns:
            A dict with:
              - ``prediction``: index of the highest-probability class,
              - ``confidence``: probability of that class,
              - ``label``: ``"apparie"`` when the class is 1, otherwise
                ``"non_apparie"``,
              - ``probabilities``: the probabilities of class 0
                (``"non_apparie"``) and class 1 (``"apparie"``).

        Raises:
            RuntimeError: If no model is loaded, or if tokenization/inference
                fails (the original exception is attached as ``__cause__``).
        """
        if not self.model_loaded:
            raise RuntimeError("KPA model not loaded")

        try:
            # Encode the two texts as one sentence pair, truncated and padded
            # to max_length, and move the tensors to the model's device.
            encoding = self.tokenizer(
                argument,
                key_point,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            ).to(self.device)

            # Inference only: no gradients needed.
            with torch.no_grad():
                outputs = self.model(**encoding)
                logits = outputs.logits

            probabilities = torch.softmax(logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
            confidence = probabilities[0][predicted_class].item()

            # The result assumes a two-class head (index 0 and index 1).
            return {
                "prediction": predicted_class,
                "confidence": confidence,
                "label": "apparie" if predicted_class == 1 else "non_apparie",
                "probabilities": {
                    "non_apparie": probabilities[0][0].item(),
                    "apparie": probabilities[0][1].item(),
                },
            }
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            raise RuntimeError(f"KPA prediction failed: {str(e)}") from e

    def get_model_info(self):
        """Return a summary of the manager's current configuration and state.

        ``device`` is the string form of the stored device (``"None"`` while
        nothing is loaded); ``num_labels`` is fixed at 2.
        """
        return {
            "model_name": self.model_id,
            "device": str(self.device),
            "max_length": self.max_length,
            "num_labels": 2,
            "loaded": self.model_loaded,
        }
# Module-level instance created at import time. Construction only initialises
# attributes; no model is loaded until load_model() is called on it.
kpa_model_manager = KpaModelManager()