# Solareva Taisia
# chore(release): initial public snapshot
# 198ccb0
"""Inference utilities for API."""
import torch
from typing import List, Optional, Dict
import logging
from models.transformer_model import RussianNewsClassifier
from utils.tokenization import RussianTextTokenizer
from utils.russian_text_utils import prepare_text_for_tokenization
from api.schemas import TagPrediction
logger = logging.getLogger(__name__)
class ModelInference:
    """
    Model inference handler.

    Loads a tokenizer and a model checkpoint once (``load_model``) and keeps
    them on the instance so that ``predict`` can be called repeatedly.

    Attributes are plain and public: ``model``, ``tokenizer`` and
    ``tag_to_idx`` stay ``None`` and ``loaded`` stays ``False`` until
    ``load_model`` succeeds.
    """

    # Maximum token lengths passed to the tokenizer for each input field.
    TITLE_MAX_LENGTH = 128
    SNIPPET_MAX_LENGTH = 256

    def __init__(
        self,
        model_path: str,
        tokenizer_name: str = "DeepPavlov/rubert-base-cased",
        device: Optional[torch.device] = None,
    ):
        """
        Initialize inference handler. Nothing is loaded here.

        Args:
            model_path: Path to model checkpoint
            tokenizer_name: HuggingFace tokenizer name; also used as the
                ``model_name`` when a model has to be rebuilt from a
                ``state_dict`` checkpoint
            device: Device for inference; defaults to CUDA when available,
                otherwise CPU
        """
        self.model_path = model_path
        self.tokenizer_name = tokenizer_name
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = None
        self.tokenizer = None
        self.tag_to_idx = None
        self.loaded = False

    def load_model(self) -> None:
        """
        Load the tokenizer and the model checkpoint onto ``self.device``.

        The checkpoint's ``tag_to_idx`` entry, when present, is stored on the
        instance. On success the model is switched to eval mode and
        ``self.loaded`` becomes ``True``.

        Raises:
            Exception: any failure is logged with its traceback,
                ``self.loaded`` is reset to ``False`` and the original
                exception is re-raised unchanged.
        """
        try:
            logger.info("Loading model from %s", self.model_path)

            # Imported here rather than at module level, as in the original.
            from utils.tokenization import create_tokenizer

            self.tokenizer = create_tokenizer(self.tokenizer_name)
            logger.info("Tokenizer loaded")

            # SECURITY: torch.load unpickles the file, which can execute
            # arbitrary code. Only point model_path at trusted checkpoints.
            # NOTE(review): newer PyTorch releases default to
            # weights_only=True, which would presumably reject checkpoints
            # that store a whole model object -- confirm against the pinned
            # torch version.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            self.model = self._model_from_checkpoint(checkpoint)

            # Load tag mapping if available
            if isinstance(checkpoint, dict) and 'tag_to_idx' in checkpoint:
                self.tag_to_idx = checkpoint['tag_to_idx']

            self.model.to(self.device)
            self.model.eval()
            self.loaded = True
            logger.info("Model loaded successfully on %s", self.device)
        except Exception as e:
            # logger.exception keeps the traceback next to the message.
            logger.exception("Failed to load model: %s", e)
            self.loaded = False
            raise

    def _model_from_checkpoint(self, checkpoint):
        """Return the model object described by a loaded checkpoint.

        Handled layouts:
            * not a dict: the loaded object itself is the model;
            * dict with ``'model'``: that entry is the model;
            * dict with ``'state_dict'``: a ``RussianNewsClassifier`` is
              built (``num_labels`` from the checkpoint, default 1000) and
              the weights are loaded into it;
            * any other dict: returned as-is.
        """
        if not isinstance(checkpoint, dict):
            return checkpoint
        if 'model' in checkpoint:
            return checkpoint['model']
        if 'state_dict' in checkpoint:
            model = RussianNewsClassifier(
                model_name=self.tokenizer_name,
                num_labels=checkpoint.get('num_labels', 1000),
                use_snippet=True,
            )
            model.load_state_dict(checkpoint['state_dict'])
            return model
        # NOTE(review): a dict with neither key is used as the model and will
        # fail at ``.to()`` in load_model -- possibly a bare state_dict was
        # intended here; confirm which checkpoint formats are really in use.
        return checkpoint

    def _encode(self, text, max_length):
        """Tokenize ``text`` and return ``(input_ids, attention_mask)``.

        Both values get a new leading dimension and are moved to
        ``self.device``.
        """
        # assumes encode() returns a mapping of un-batched tensors -- TODO confirm
        encoded = self.tokenizer.encode(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        input_ids = encoded['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoded['attention_mask'].unsqueeze(0).to(self.device)
        return input_ids, attention_mask

    def predict(
        self,
        title: str,
        snippet: Optional[str] = None,
        threshold: float = 0.5,
        top_k: Optional[int] = None,
    ) -> "List[TagPrediction]":
        """
        Run inference.

        Args:
            title: Article title
            snippet: Optional article snippet; ignored when empty or when it
                is empty after text preparation
            threshold: Classification threshold; tags whose sigmoid score is
                below it are dropped
            top_k: Return at most the top K predictions. ``None`` or a
                non-positive value means no limit.

        Returns:
            List of tag predictions sorted by descending score. Tags are
            named through ``self.tag_to_idx`` when it is set, otherwise (or
            for indices missing from it) as ``tag_<index>``.

        Raises:
            RuntimeError: if the model has not been loaded.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded")

        # Prepare text
        title_clean = prepare_text_for_tokenization(title)
        snippet_clean = prepare_text_for_tokenization(snippet) if snippet else None

        # Tokenize; the snippet inputs stay None when there is no snippet.
        title_input_ids, title_attention_mask = self._encode(
            title_clean, self.TITLE_MAX_LENGTH
        )
        snippet_input_ids = None
        snippet_attention_mask = None
        if snippet_clean:
            snippet_input_ids, snippet_attention_mask = self._encode(
                snippet_clean, self.SNIPPET_MAX_LENGTH
            )

        # Inference
        with torch.no_grad():
            logits = self.model(
                title_input_ids=title_input_ids,
                title_attention_mask=title_attention_mask,
                snippet_input_ids=snippet_input_ids,
                snippet_attention_mask=snippet_attention_mask,
            )
            # [0] picks the scores of the single item in the batch.
            probs = torch.sigmoid(logits).cpu().numpy()[0]

        # Invert the tag mapping when there is one; an empty mapping makes
        # every index fall back to its generic "tag_<idx>" name.
        idx_to_tag = (
            {idx: tag for tag, idx in self.tag_to_idx.items()}
            if self.tag_to_idx
            else {}
        )
        predictions = [
            TagPrediction(tag=idx_to_tag.get(idx, f"tag_{idx}"), score=float(prob))
            for idx, prob in enumerate(probs)
            if prob >= threshold
        ]

        # Sort by score and apply top_k
        predictions.sort(key=lambda pred: pred.score, reverse=True)
        # Only a positive top_k limits the result: a negative value would
        # otherwise slice from the end and silently drop the tail.
        if top_k is not None and top_k > 0:
            predictions = predictions[:top_k]
        return predictions