import torch
import onnxruntime as ort
from transformers import XLMRobertaTokenizer
import numpy as np
import os


class OptimizedToxicityClassifier:
    """High-performance toxicity classifier for production."""

    def __init__(self, onnx_path=None, pytorch_path=None, device='cuda'):
        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')

        # Supported language codes mapped to the model's language IDs
        self.lang_map = {
            'en': 0, 'ru': 1, 'tr': 2, 'es': 3,
            'fr': 4, 'it': 5, 'pt': 6
        }

        # Output labels, in the order the model emits them
        self.label_names = [
            'toxic', 'severe_toxic', 'obscene',
            'threat', 'insult', 'identity_hate'
        ]

        if onnx_path and os.path.exists(onnx_path):
            # Prefer CUDA when requested and available, otherwise fall back to CPU
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] \
                if device == 'cuda' and 'CUDAExecutionProvider' in ort.get_available_providers() \
                else ['CPUExecutionProvider']

            self.session = ort.InferenceSession(onnx_path, providers=providers)
            self.use_onnx = True
            print(f"Loaded ONNX model from {onnx_path}")

        elif pytorch_path:
            from model.language_aware_transformer import LanguageAwareTransformer

            # Resolve a checkpoint directory down to a concrete weights file
            if os.path.isdir(pytorch_path):
                # Prefer the 'latest' symlink when it exists and resolves
                latest_path = os.path.join(pytorch_path, 'latest')
                if os.path.islink(latest_path) and os.path.exists(latest_path):
                    checkpoint_dir = latest_path
                else:
                    # Fall back to the last checkpoint directory; the lexicographic
                    # sort assumes zero-padded epoch numbers in the directory names
                    checkpoint_dirs = [d for d in os.listdir(pytorch_path) if d.startswith('checkpoint_epoch')]
                    if checkpoint_dirs:
                        checkpoint_dirs.sort()
                        checkpoint_dir = os.path.join(pytorch_path, checkpoint_dirs[-1])
                    else:
                        raise ValueError(f"No checkpoint directories found in {pytorch_path}")

                # Find a weights file inside the resolved checkpoint directory
                model_file = None
                potential_files = ['pytorch_model.bin', 'model.pt', 'model.pth']
                for file in potential_files:
                    candidate = os.path.join(checkpoint_dir, file)
                    if os.path.exists(candidate):
                        model_file = candidate
                        break

                if not model_file:
                    raise FileNotFoundError(f"No model file found in {checkpoint_dir}")

                print(f"Using model from checkpoint: {checkpoint_dir}")
                model_path = model_file
            else:
                # pytorch_path already points at a weights file
                model_path = pytorch_path

            self.model = LanguageAwareTransformer(num_labels=6)
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            self.model.to(device)
            self.model.eval()
            self.use_onnx = False
            self.device = device
            print(f"Loaded PyTorch model from {model_path}")
        else:
            raise ValueError("Either onnx_path or pytorch_path must be provided")

    def predict(self, texts, langs=None, batch_size=8):
        """
        Predict toxicity for a list of texts.

        Args:
            texts: List of text strings
            langs: List of language codes (e.g., 'en', 'fr'); defaults to English
            batch_size: Batch size for processing

        Returns:
            List of dictionaries with toxicity predictions
        """
        results = []

        # Default to English when no languages are provided
        if langs is None:
            langs = ['en'] * len(texts)

        # Map language codes to model language IDs (unknown codes fall back to English)
        lang_ids = [self.lang_map.get(lang, 0) for lang in langs]

        # Process texts in batches
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            batch_langs = lang_ids[i:i+batch_size]

            # Tokenize the batch
            inputs = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )

            if self.use_onnx:
                # ONNX Runtime expects numpy inputs; logits are turned into
                # probabilities with a sigmoid
                ort_inputs = {
                    'input_ids': inputs['input_ids'].numpy(),
                    'attention_mask': inputs['attention_mask'].numpy(),
                    'lang_ids': np.array(batch_langs, dtype=np.int64)
                }
                ort_outputs = self.session.run(None, ort_inputs)
                probabilities = 1 / (1 + np.exp(-ort_outputs[0]))
            else:
                # PyTorch inference; the model returns probabilities under the
                # 'probabilities' key
                with torch.no_grad():
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    lang_tensor = torch.tensor(batch_langs, dtype=torch.long, device=self.device)
                    outputs = self.model(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        lang_ids=lang_tensor,
                        mode='inference'
                    )
                    probabilities = outputs['probabilities'].cpu().numpy()

            # Apply per-label decision thresholds and assemble results
            for text, lang, probs in zip(batch_texts, langs[i:i+batch_size], probabilities):
                # Per-label thresholds; language-specific overrides can be added
                # alongside the 'default' entry
                lang_thresholds = {
                    'default': [0.60, 0.54, 0.60, 0.48, 0.60, 0.50]
                }

                thresholds = lang_thresholds.get(lang, lang_thresholds['default'])
                is_toxic = (probs >= np.array(thresholds)).astype(bool)

                result = {
                    'text': text,
                    'language': lang,
                    'probabilities': {
                        label: float(prob) for label, prob in zip(self.label_names, probs)
                    },
                    'is_toxic': bool(is_toxic.any()),
                    'toxic_categories': [
                        self.label_names[k] for k in range(len(is_toxic)) if is_toxic[k]
                    ]
                }
                results.append(result)

        return results
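

# Example usage (an illustrative sketch, not part of the class above). The
# checkpoint path 'weights/toxic_classifier.onnx' is a hypothetical placeholder;
# substitute the path to your exported model. Language codes must be ones
# present in lang_map ('en', 'ru', 'tr', 'es', 'fr', 'it', 'pt').
if __name__ == '__main__':
    classifier = OptimizedToxicityClassifier(
        onnx_path='weights/toxic_classifier.onnx',  # hypothetical path
        device='cpu'
    )
    predictions = classifier.predict(
        ["You are wonderful!", "Je te déteste"],
        langs=['en', 'fr']
    )
    for pred in predictions:
        print(pred['language'], pred['is_toxic'], pred['toxic_categories'])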