# test-model / modeling_longformer.py
# WilC90's picture
# Update modeling_longformer.py
# 2d226d2 verified
import torch
import torch.nn as nn
import re
import unicodedata
from typing import Any, List, Union
from transformers import LongformerPreTrainedModel, LongformerModel, AutoTokenizer
from .configuration_longformer import LongformerIntentConfig
def clean_text(s: Any, normalization: str = "NFKC", flatten_whitespace: bool = True) -> Any:
    """Normalize a text string prior to tokenization.

    Non-string inputs are returned unchanged, so the function can be mapped
    safely over heterogeneous values.

    Args:
        s: Value to clean; only ``str`` instances are processed.
        normalization: Unicode normalization form (``"NFC"``, ``"NFD"``,
            ``"NFKC"``, ``"NFKD"``), or ``"none"`` to skip normalization.
        flatten_whitespace: When True, collapse ALL whitespace runs
            (including newlines) to single spaces; when False, collapse only
            spaces/tabs so line breaks are preserved.

    Returns:
        The cleaned string, or ``s`` unchanged when it is not a string.
    """
    if not isinstance(s, str):
        return s
    # Unify Windows/old-Mac line endings to "\n".
    s = s.replace("\r\n", "\n").replace("\r", "\n")
    # Single C-level pass: map Unicode line/paragraph separators to "\n",
    # the non-breaking space to a plain space, and drop zero-width characters.
    s = s.translate({
        0x2028: "\n",  # LINE SEPARATOR
        0x2029: "\n",  # PARAGRAPH SEPARATOR
        0x00A0: " ",   # NO-BREAK SPACE
        0x200B: None,  # ZERO WIDTH SPACE
        0xFEFF: None,  # BOM / ZERO WIDTH NO-BREAK SPACE
        0x180E: None,  # MONGOLIAN VOWEL SEPARATOR
    })
    if normalization != "none":
        s = unicodedata.normalize(normalization, s)
    if flatten_whitespace:
        s = re.sub(r"\s+", " ", s).strip()
    else:
        s = re.sub(r"[ \t]+", " ", s).strip()
    return s
class LongformerClassificationHead(nn.Module):
    """Sequence-classification head: pool the first token of the encoder
    output, project it through a tanh bottleneck, and emit label logits."""

    def __init__(self, config):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        """Map ``(batch, seq, hidden)`` states to ``(batch, num_labels)`` logits."""
        # Pool by taking the first token's representation (<s>/CLS-equivalent).
        pooled = hidden_states[:, 0, :]
        pooled = self.dense(self.dropout(pooled))
        pooled = self.dropout(torch.tanh(pooled))
        return self.out_proj(pooled)
class LongformerIntentModel(LongformerPreTrainedModel):
    """Longformer encoder with an intent-classification head.

    The tokenizer is loaded from the same checkpoint path as the weights so
    that :meth:`predict` can be called directly on raw strings.
    """

    config_class = LongformerIntentConfig

    def __init__(self, config):
        super().__init__(config)
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)
        # The tokenizer is loaded here, from the path stored in the config,
        # so the model is self-contained for inference.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, global_attention_mask=None, labels=None):
        """Run the encoder and the classification head.

        Returns:
            A dict with ``"logits"`` of shape ``(batch, num_labels)``. When
            ``labels`` is provided, a cross-entropy ``"loss"`` is included as
            well (previously ``labels`` was accepted but silently ignored).
        """
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        logits = self.classifier(outputs[0])
        # Return a plain dict so HF pipelines can consume the output.
        result = {"logits": logits}
        if labels is not None:
            # Standard single-label classification loss (assumes integer
            # class-index labels — confirm against the training data format).
            result["loss"] = nn.functional.cross_entropy(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return result

    def predict(self, texts: Union[str, List[str]], batch_size: int = 8, device: Union[str, None] = None):
        """Score raw texts and return integer intent percentages.

        Args:
            texts: A single string or a list of strings.
            batch_size: Number of texts tokenized/scored per forward pass.
            device: Target device; auto-selects CUDA when available.

        Returns:
            One dict per input text mapping each name in
            ``config.intent_columns`` to an integer percentage; the values of
            each dict sum to exactly 100.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.to(device)
        self.eval()
        if isinstance(texts, str):
            texts = [texts]
        all_results = []
        for start in range(0, len(texts), batch_size):
            batch_texts = [clean_text(t) for t in texts[start : start + batch_size]]
            # Tokenize with the tokenizer loaded in __init__.
            enc = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                # RoBERTa-derived models (incl. Longformer) reserve 2 position
                # embeddings for the padding-index offset, so the usable
                # sequence length is max_position_embeddings - 2; the raw
                # value caused position-embedding index errors on
                # maximal-length inputs.
                max_length=self.config.max_position_embeddings - 2,
                return_tensors="pt",
            ).to(device)
            # Global attention on the first (<s>) token only, as recommended
            # for Longformer sequence classification.
            global_mask = torch.zeros_like(enc["input_ids"])
            global_mask[:, 0] = 1
            with torch.no_grad():
                outputs = self.forward(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    global_attention_mask=global_mask,
                )
            probs = torch.softmax(outputs["logits"], dim=-1).cpu().numpy()
            for row in probs:
                # Round to integer percentages, then push any rounding drift
                # onto the largest class so each row sums to exactly 100.
                pct = (row * 100).round().astype(int)
                drift = 100 - pct.sum()
                if drift != 0:
                    pct[pct.argmax()] += drift
                all_results.append(dict(zip(self.config.intent_columns, pct.tolist())))
        return all_results