File size: 5,880 Bytes
198ccb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Inference utilities for API."""

import torch
from typing import List, Optional, Dict
import logging

from models.transformer_model import RussianNewsClassifier
from utils.tokenization import RussianTextTokenizer
from utils.russian_text_utils import prepare_text_for_tokenization
from api.schemas import TagPrediction

logger = logging.getLogger(__name__)


class ModelInference:
    """
    Model inference handler.

    Loads a checkpoint (full module, state_dict, or raw serialized model),
    caches the model and tokenizer on the chosen device, and runs
    multi-label inference producing ``TagPrediction`` objects.
    """

    def __init__(
        self,
        model_path: str,
        tokenizer_name: str = "DeepPavlov/rubert-base-cased",
        device: Optional[torch.device] = None,
    ):
        """
        Initialize inference handler.

        Args:
            model_path: Path to model checkpoint
            tokenizer_name: HuggingFace tokenizer name
            device: Device for inference (defaults to CUDA when available)
        """
        self.model_path = model_path
        self.tokenizer_name = tokenizer_name
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = None
        self.tokenizer = None
        # Optional mapping tag -> class index, taken from the checkpoint.
        self.tag_to_idx = None
        self.loaded = False

    def load_model(self) -> None:
        """
        Load model and tokenizer from ``self.model_path``.

        Supports three checkpoint layouts:
          * dict with a ``'model'`` key holding a full module,
          * dict with ``'state_dict'`` (plus optional ``'num_labels'``,
            defaulting to 1000),
          * a bare serialized module.

        Raises:
            Exception: re-raises any failure after logging; ``self.loaded``
                remains False in that case.
        """
        try:
            logger.info(f"Loading model from {self.model_path}")

            # Load tokenizer
            from utils.tokenization import create_tokenizer
            self.tokenizer = create_tokenizer(self.tokenizer_name)
            logger.info("Tokenizer loaded")

            # SECURITY: torch.load unpickles arbitrary objects — only load
            # trusted checkpoints. weights_only=False must be explicit here
            # because some supported layouts store a full model object, and
            # the torch.load default flipped to weights_only=True in
            # torch 2.6, which would break those checkpoints.
            checkpoint = torch.load(
                self.model_path,
                map_location=self.device,
                weights_only=False,
            )

            # Handle different checkpoint formats (see docstring).
            if isinstance(checkpoint, dict):
                if 'model' in checkpoint:
                    self.model = checkpoint['model']
                elif 'state_dict' in checkpoint:
                    num_labels = checkpoint.get('num_labels', 1000)
                    self.model = RussianNewsClassifier(
                        model_name=self.tokenizer_name,
                        num_labels=num_labels,
                        use_snippet=True,
                    )
                    self.model.load_state_dict(checkpoint['state_dict'])
                else:
                    # Dict without known keys: assume it is the model itself.
                    self.model = checkpoint
            else:
                self.model = checkpoint

            # Load tag mapping if the checkpoint provides one.
            if isinstance(checkpoint, dict) and 'tag_to_idx' in checkpoint:
                self.tag_to_idx = checkpoint['tag_to_idx']

            self.model.to(self.device)
            self.model.eval()
            self.loaded = True

            logger.info(f"Model loaded successfully on {self.device}")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.loaded = False
            raise

    def _encode(self, text: str, max_length: int):
        """
        Tokenize one text into batched ``(1, max_length)`` tensors.

        Returns:
            Tuple of (input_ids, attention_mask), both moved to self.device.
        """
        encoded = self.tokenizer.encode(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
        )
        input_ids = encoded['input_ids'].unsqueeze(0).to(self.device)
        attention_mask = encoded['attention_mask'].unsqueeze(0).to(self.device)
        return input_ids, attention_mask

    def predict(
        self,
        title: str,
        snippet: Optional[str] = None,
        threshold: float = 0.5,
        top_k: Optional[int] = None,
    ) -> List[TagPrediction]:
        """
        Run inference for a single article.

        Args:
            title: Article title
            snippet: Optional article snippet
            threshold: Minimum sigmoid probability to keep a tag
            top_k: If given, keep at most this many highest-scoring tags

        Returns:
            Tag predictions sorted by descending score.

        Raises:
            RuntimeError: If load_model() has not completed successfully.
        """
        if not self.loaded:
            raise RuntimeError("Model not loaded")

        # Prepare and tokenize the title (capped at 128 tokens).
        title_clean = prepare_text_for_tokenization(title)
        title_input_ids, title_attention_mask = self._encode(title_clean, max_length=128)

        # Snippet is optional; the model accepts None tensors for it.
        snippet_input_ids = None
        snippet_attention_mask = None
        snippet_clean = prepare_text_for_tokenization(snippet) if snippet else None
        if snippet_clean:
            snippet_input_ids, snippet_attention_mask = self._encode(
                snippet_clean, max_length=256
            )

        # Multi-label inference: one independent sigmoid probability per tag.
        with torch.no_grad():
            logits = self.model(
                title_input_ids=title_input_ids,
                title_attention_mask=title_attention_mask,
                snippet_input_ids=snippet_input_ids,
                snippet_attention_mask=snippet_attention_mask,
            )
            probs = torch.sigmoid(logits).cpu().numpy()[0]

        # Map class indices to tag names; fall back to synthetic "tag_<idx>"
        # labels when the checkpoint carried no mapping (this single path
        # replaces the two near-identical loops in the original).
        idx_to_tag: Dict[int, str] = (
            {v: k for k, v in self.tag_to_idx.items()} if self.tag_to_idx else {}
        )
        predictions = [
            TagPrediction(tag=idx_to_tag.get(idx, f"tag_{idx}"), score=float(prob))
            for idx, prob in enumerate(probs)
            if prob >= threshold
        ]

        # Highest score first, then optionally truncate.
        predictions.sort(key=lambda p: p.score, reverse=True)

        # `is not None` (rather than truthiness) so top_k=0 legitimately
        # returns an empty list instead of being silently ignored.
        if top_k is not None:
            predictions = predictions[:top_k]

        return predictions