|
|
"""Inference utilities for API.""" |
|
|
|
|
|
import torch |
|
|
from typing import List, Optional, Dict |
|
|
import logging |
|
|
|
|
|
from models.transformer_model import RussianNewsClassifier |
|
|
from utils.tokenization import RussianTextTokenizer |
|
|
from utils.russian_text_utils import prepare_text_for_tokenization |
|
|
from api.schemas import TagPrediction |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class ModelInference:
    """
    Model inference handler.

    Loads a checkpoint and its tokenizer once via ``load_model`` and keeps
    them on the instance; ``predict`` then runs synchronous, single-article
    multi-label inference (sigmoid over the model logits).
    """

    # Token budgets passed to the tokenizer for each input field.
    TITLE_MAX_LENGTH = 128
    SNIPPET_MAX_LENGTH = 256

    # Label count assumed when a state-dict checkpoint does not record one.
    DEFAULT_NUM_LABELS = 1000

    def __init__(
        self,
        model_path: str,
        tokenizer_name: str = "DeepPavlov/rubert-base-cased",
        device: Optional[torch.device] = None,
    ):
        """
        Initialize inference handler.

        Nothing is loaded here; call ``load_model`` before ``predict``.

        Args:
            model_path: Path to model checkpoint
            tokenizer_name: HuggingFace tokenizer name (also used as the
                backbone name when the model is rebuilt from a state dict)
            device: Device for inference; defaults to CUDA when available,
                otherwise CPU
        """
        self.model_path = model_path
        self.tokenizer_name = tokenizer_name
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Populated by load_model().
        self.model = None
        self.tokenizer = None
        # Optional mapping tag name -> output index, read from the checkpoint.
        self.tag_to_idx = None
        self.loaded = False

    def load_model(self) -> None:
        """
        Load model and tokenizer.

        Supported checkpoint layouts:
            * a dict with a ``'model'`` entry holding the model object,
            * a dict with a ``'state_dict'`` entry (and optionally
              ``'num_labels'``), from which a ``RussianNewsClassifier`` is
              rebuilt,
            * a non-dict object, which is used as the model directly.

        A ``'tag_to_idx'`` entry in a dict checkpoint, when present, is
        stored on the instance.

        Raises:
            ValueError: If the checkpoint is a dict with neither a
                ``'model'`` nor a ``'state_dict'`` entry.
            Exception: Any other loading error is logged and re-raised;
                ``self.loaded`` is ``False`` afterwards.
        """
        try:
            logger.info("Loading model from %s", self.model_path)

            from utils.tokenization import create_tokenizer

            self.tokenizer = create_tokenizer(self.tokenizer_name)
            logger.info("Tokenizer loaded")

            # NOTE(security): torch.load unpickles arbitrary Python objects;
            # only point model_path at trusted checkpoints.
            # NOTE(review): recent PyTorch releases default to
            # weights_only=True, which may reject checkpoints that store a
            # full model object -- confirm against the deployed version.
            checkpoint = torch.load(self.model_path, map_location=self.device)

            self.model = self._build_model(checkpoint)

            if isinstance(checkpoint, dict) and 'tag_to_idx' in checkpoint:
                self.tag_to_idx = checkpoint['tag_to_idx']

            self.model.to(self.device)
            self.model.eval()
            self.loaded = True

            logger.info("Model loaded successfully on %s", self.device)

        except Exception as e:
            logger.error("Failed to load model: %s", e)
            self.loaded = False
            raise

    def _build_model(self, checkpoint):
        """Return the model object described by a loaded checkpoint."""
        if not isinstance(checkpoint, dict):
            # The checkpoint is taken to be the model object itself.
            return checkpoint

        if 'model' in checkpoint:
            return checkpoint['model']

        if 'state_dict' in checkpoint:
            if 'num_labels' in checkpoint:
                num_labels = checkpoint['num_labels']
            else:
                num_labels = self.DEFAULT_NUM_LABELS
                logger.warning(
                    "Checkpoint has no 'num_labels'; assuming %d",
                    num_labels,
                )
            model = RussianNewsClassifier(
                model_name=self.tokenizer_name,
                num_labels=num_labels,
                use_snippet=True,
            )
            model.load_state_dict(checkpoint['state_dict'])
            return model

        # A plain dict is not a model: fail here with a clear message rather
        # than later with an AttributeError on `.to()`.
        raise ValueError(
            "Unsupported checkpoint format: expected a 'model' or "
            f"'state_dict' entry, found keys {sorted(map(str, checkpoint))}"
        )

    def _encode(self, text: str, max_length: int):
        """Tokenize ``text`` and return batched (input_ids, attention_mask) on the device."""
        encoded = self.tokenizer.encode(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        # unsqueeze(0) adds the leading batch dimension for a single example.
        input_ids = encoded['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoded['attention_mask'].unsqueeze(0).to(self.device)
        return input_ids, attention_mask

    @staticmethod
    def _select_indices(probs, threshold: float, top_k: Optional[int]):
        """Return (index, score) pairs with score >= threshold, best first, capped at top_k."""
        hits = [
            (idx, float(prob))
            for idx, prob in enumerate(probs)
            if prob >= threshold
        ]
        # Stable sort: equal scores keep ascending index order.
        hits.sort(key=lambda pair: pair[1], reverse=True)

        # Only a positive top_k limits the result; a negative value would
        # otherwise slice from the end and silently drop the last entries.
        if top_k is not None and top_k > 0:
            hits = hits[:top_k]
        return hits

    def predict(
        self,
        title: str,
        snippet: Optional[str] = None,
        threshold: float = 0.5,
        top_k: Optional[int] = None,
    ) -> "List[TagPrediction]":
        """
        Run inference.

        Args:
            title: Article title
            snippet: Optional article snippet; skipped when empty or when it
                is empty after text preparation
            threshold: Classification threshold; tags whose sigmoid score is
                greater than or equal to it are returned
            top_k: Return top K predictions; ``None`` or a non-positive
                value means no limit

        Returns:
            List of tag predictions, sorted by score in descending order.
            Tags are named via ``tag_to_idx`` when available, otherwise
            ``tag_<index>``.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded")

        title_clean = prepare_text_for_tokenization(title)
        snippet_clean = prepare_text_for_tokenization(snippet) if snippet else None

        title_input_ids, title_attention_mask = self._encode(
            title_clean, self.TITLE_MAX_LENGTH
        )

        snippet_input_ids = None
        snippet_attention_mask = None
        if snippet_clean:
            snippet_input_ids, snippet_attention_mask = self._encode(
                snippet_clean, self.SNIPPET_MAX_LENGTH
            )

        with torch.no_grad():
            logits = self.model(
                title_input_ids=title_input_ids,
                title_attention_mask=title_attention_mask,
                snippet_input_ids=snippet_input_ids,
                snippet_attention_mask=snippet_attention_mask,
            )

        # Independent per-tag probabilities for the single example in the batch.
        probs = torch.sigmoid(logits).cpu().numpy()[0]

        # Invert the mapping once; indices without a name fall back to "tag_<idx>".
        if self.tag_to_idx:
            idx_to_tag = {v: k for k, v in self.tag_to_idx.items()}
        else:
            idx_to_tag = {}

        return [
            TagPrediction(tag=idx_to_tag.get(idx, f"tag_{idx}"), score=score)
            for idx, score in self._select_indices(probs, threshold, top_k)
        ]
|
|
|
|
|
|