"""Longformer-based intent classification model.

Bundles a Longformer encoder, a CLS-token classification head, and an
internally loaded tokenizer, exposing a convenience ``predict`` method that
returns integer percentage scores per intent label.
"""

import re
import unicodedata
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoTokenizer, LongformerModel, LongformerPreTrainedModel

from .configuration_longformer import LongformerIntentConfig


def clean_text(s: Any, normalization: str = "NFKC", flatten_whitespace: bool = True) -> Any:
    """Normalize a string for tokenization; non-strings pass through unchanged.

    Steps: unify newline conventions, drop zero-width characters, apply
    Unicode normalization, and collapse whitespace.

    Args:
        s: Input value; returned as-is when it is not a ``str``.
        normalization: ``unicodedata.normalize`` form ("NFC", "NFKC", ...),
            or the literal string "none" to skip normalization.
        flatten_whitespace: When True, collapse ALL whitespace (including
            newlines) into single spaces; otherwise collapse only spaces and
            tabs, preserving line breaks.

    Returns:
        The cleaned string, or the original object when not a string.
    """
    if not isinstance(s, str):
        return s
    # Unify every newline convention (CRLF, bare CR, Unicode line/paragraph
    # separators) to "\n".
    s = s.replace("\r\n", "\n").replace("\r", "\n")
    for ch in ("\u2028", "\u2029"):
        s = s.replace(ch, "\n")
    # Non-breaking space -> regular space.
    s = s.replace("\xa0", " ")
    # Remove zero-width space, BOM, and the Mongolian vowel separator.
    for ch in ("\u200b", "\ufeff", "\u180e"):
        s = s.replace(ch, "")
    if normalization != "none":
        s = unicodedata.normalize(normalization, s)
    if flatten_whitespace:
        s = re.sub(r"\s+", " ", s).strip()
    else:
        s = re.sub(r"[ \t]+", " ", s).strip()
    return s


class LongformerClassificationHead(nn.Module):
    """Classification head over the first-token (CLS) hidden state.

    Pipeline: dropout -> dense -> tanh -> dropout -> output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        # Use the first token's hidden state as the sequence summary.
        x = hidden_states[:, 0, :]
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class LongformerIntentModel(LongformerPreTrainedModel):
    """Longformer encoder plus classification head for intent prediction.

    The tokenizer is loaded at construction time from the path stored in the
    configuration, so a single ``from_pretrained`` call yields a model that
    can run ``predict`` on raw text with no extra setup.
    """

    config_class = LongformerIntentConfig

    def __init__(self, config):
        super().__init__(config)
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)
        # THE TOKENIZER IS LOADED HERE, using the path saved in the config.
        # NOTE(review): ``_name_or_path`` is a private HF attribute; it is
        # populated by ``from_pretrained`` but may be empty when the model is
        # built from a bare config — confirm all construction paths set it.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, global_attention_mask=None, labels=None):
        """Encode ``input_ids`` and classify the CLS representation.

        ``labels`` is accepted for interface compatibility with HF trainers
        but is ignored here — no loss is computed.

        Returns:
            ``{"logits": Tensor}`` so the HF pipeline machinery is happy.
        """
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        logits = self.classifier(outputs[0])
        return {"logits": logits}

    def predict(
        self,
        texts: Union[str, List[str]],
        batch_size: int = 8,
        device: Optional[str] = None,
    ):
        """Classify raw text(s) into integer intent percentages.

        Args:
            texts: A single string or a list of strings.
            batch_size: Number of texts encoded per forward pass.
            device: Target device string ("cuda"/"cpu"); auto-detected when
                ``None`` (fix: annotated ``Optional[str]``, not ``str``).

        Returns:
            One dict per input text mapping each label in
            ``config.intent_columns`` to an integer percentage; each dict's
            values sum to exactly 100.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)
        self.eval()
        if isinstance(texts, str):
            texts = [texts]

        all_results = []
        for start in range(0, len(texts), batch_size):
            batch_texts = [clean_text(t) for t in texts[start : start + batch_size]]
            # Encode with the internal tokenizer (self.tokenizer).
            # NOTE(review): Longformer's ``max_position_embeddings`` includes
            # a +2 padding-index offset; truncating to it may leave sequences
            # 2 tokens longer than the usable positions — confirm against the
            # checkpoint's tokenizer config (``model_max_length``).
            enc = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=self.config.max_position_embeddings,
                return_tensors="pt",
            ).to(device)

            # Global attention on the first (CLS) token only, as recommended
            # for Longformer sequence classification.
            global_mask = torch.zeros_like(enc["input_ids"])
            global_mask[:, 0] = 1

            with torch.no_grad():
                # Fix: call the module itself rather than ``self.forward`` so
                # nn.Module hooks (e.g. for quantization/tracing) still fire.
                outputs = self(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    global_attention_mask=global_mask,
                )

            probs = torch.softmax(outputs["logits"], dim=-1).cpu().numpy()
            for row in probs:
                # Round to whole percentages, then push any rounding drift
                # onto the largest entry so the row sums to exactly 100.
                pct = (row * 100).round().astype(int)
                drift = 100 - pct.sum()
                if drift != 0:
                    pct[pct.argmax()] += drift
                all_results.append(dict(zip(self.config.intent_columns, pct.tolist())))
        return all_results