Wilson committed on
Commit
01b24aa
verified
1 Parent(s): 54c9f7c

Create modeling_longformer.py

Browse files
Files changed (1) hide show
  1. modeling_longformer.py +79 -0
modeling_longformer.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import unicodedata
from types import SimpleNamespace
from typing import Any, List, Union

import torch
import torch.nn as nn
from transformers import LongformerPreTrainedModel, LongformerModel

from .configuration_longformer import LongformerIntentConfig
8
+
9
# --- Integrated text-cleaning logic ---
def clean_text(s: Any, normalization: str = "NFKC", flatten_whitespace: bool = True) -> Any:
    """Normalize a piece of text for tokenization.

    Unifies newline conventions, maps non-breaking spaces to plain spaces,
    strips invisible characters (zero-width space, BOM, Mongolian vowel
    separator), applies a Unicode normalization form, and collapses
    whitespace. Non-string inputs are returned untouched.

    Args:
        s: Candidate text; anything that is not a ``str`` passes through.
        normalization: A ``unicodedata.normalize`` form, or ``"none"`` to skip.
        flatten_whitespace: When True, every whitespace run (including
            newlines) collapses to a single space; when False, only spaces
            and tabs are collapsed and newlines survive.

    Returns:
        The cleaned string, or ``s`` unchanged when it is not a string.
    """
    if not isinstance(s, str):
        return s
    # CRLF must collapse to a single "\n" before the per-character pass,
    # otherwise "\r\n" would become two newlines.
    s = s.replace("\r\n", "\n")
    # One C-level pass handles the remaining single-character fixes.
    s = s.translate(str.maketrans({
        "\r": "\n",       # stray carriage returns
        "\u2028": "\n",   # line separator
        "\u2029": "\n",   # paragraph separator
        "\xa0": " ",      # non-breaking space
        "\u200b": None,   # zero-width space
        "\ufeff": None,   # byte-order mark
        "\u180e": None,   # Mongolian vowel separator
    }))
    if normalization != "none":
        s = unicodedata.normalize(normalization, s)
    pattern = r"\s+" if flatten_whitespace else r"[ \t]+"
    return re.sub(pattern, " ", s).strip()
20
+
21
# --- Model architecture ---
class LongformerIntentModel(LongformerPreTrainedModel):
    """Longformer encoder with a linear intent-classification head.

    The classification head reads the hidden state of the first (``<s>``)
    token. :meth:`predict` is a self-contained inference helper: it cleans
    text, batches, tokenizes, puts global attention on the ``<s>`` token,
    and converts class probabilities into integer percentages that always
    sum to 100.
    """

    config_class = LongformerIntentConfig

    def __init__(self, config):
        super().__init__(config)
        self.longformer = LongformerModel(config)
        # Linear head mapping the <s> representation to intent logits.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, global_attention_mask=None, labels=None):
        """Encode the batch and classify from the ``<s>`` (index 0) token.

        Returns:
            A namespace with ``logits`` of shape (batch, num_labels) and
            ``loss`` (cross-entropy, or ``None`` when ``labels`` is absent).
        """
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        cls_repr = outputs[0][:, 0, :]  # hidden state of the <s> token
        logits = self.classifier(cls_repr)
        # BUGFIX: `labels` used to be accepted but silently ignored; return
        # the standard classification loss so training code can use it.
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        # SimpleNamespace instead of building a throwaway class per call.
        return SimpleNamespace(logits=logits, loss=loss)

    # --- Inference API (adapted from infer.py) ---
    def predict(self, texts: Union[str, List[str]], tokenizer, batch_size: int = 8, device: str = None):
        """Classify one or many texts.

        Args:
            texts: A single string or a list of strings.
            tokenizer: The tokenizer matching this checkpoint.
            batch_size: Number of texts encoded per forward pass.
            device: Explicit device; defaults to CUDA when available.

        Returns:
            One dict per input mapping ``config.intent_columns`` names to
            integer percentages summing to 100.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.to(device)
        self.eval()

        if isinstance(texts, str):
            texts = [texts]

        # BUGFIX: Longformer uses RoBERTa-style position embeddings, which
        # reserve 2 slots for padding, so the usable sequence length is
        # max_position_embeddings - 2. Tokenizing to the full table size
        # could index past the embedding table on maximum-length inputs.
        # (NOTE(review): assumes the standard RoBERTa offset — confirm
        # against the checkpoint's config.)
        max_len = self.config.max_position_embeddings - 2

        all_results = []
        for start in range(0, len(texts), batch_size):
            batch_texts = [clean_text(t) for t in texts[start : start + batch_size]]

            enc = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors="pt",
            ).to(device)

            # Global attention on the <s> token (index 0).
            global_mask = torch.zeros_like(enc["input_ids"])
            global_mask[:, 0] = 1

            with torch.no_grad():
                outputs = self.forward(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    global_attention_mask=global_mask,
                )

            probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()

            for row in probs:
                pct = (row * 100).round().astype(int)
                # Rounding can leave the total slightly off 100; put the
                # remainder on the most likely class so each row sums to 100.
                diff = 100 - pct.sum()
                if diff != 0:
                    pct[pct.argmax()] += diff
                all_results.append(dict(zip(self.config.intent_columns, pct.tolist())))

        return all_results