|
|
import torch |
|
|
import torch.nn as nn |
|
|
import re |
|
|
import unicodedata |
|
|
from typing import Any, List, Union |
|
|
from transformers import LongformerPreTrainedModel, LongformerModel, AutoTokenizer |
|
|
from .configuration_longformer import LongformerIntentConfig |
|
|
|
|
|
def clean_text(s: Any, normalization: str = "NFKC", flatten_whitespace: bool = True) -> Any:
    """Normalize a text string for tokenization; non-strings pass through untouched.

    Steps, in order: unify line endings to ``\\n``, map Unicode paragraph/line
    separators to ``\\n``, map non-breaking spaces to plain spaces, strip
    zero-width characters, apply Unicode normalization (unless ``"none"``),
    then collapse whitespace.

    Args:
        s: Candidate text. Anything that is not a ``str`` is returned as-is.
        normalization: A ``unicodedata.normalize`` form ("NFKC", "NFC", ...)
            or ``"none"`` to skip normalization entirely.
        flatten_whitespace: When True, collapse ALL runs of whitespace
            (including newlines) to a single space; when False, collapse only
            spaces/tabs and keep newlines.

    Returns:
        The cleaned string, or the original object when it is not a string.
    """
    if not isinstance(s, str):
        return s

    # Unify Windows/Mac line endings and Unicode line/paragraph separators.
    text = s.replace("\r\n", "\n").replace("\r", "\n")
    text = text.replace("\u2028", "\n").replace("\u2029", "\n")

    # Non-breaking space becomes a regular space.
    text = text.replace("\xa0", " ")

    # Drop zero-width space, BOM, and Mongolian vowel separator.
    for invisible in ("\u200b", "\ufeff", "\u180e"):
        text = text.replace(invisible, "")

    if normalization != "none":
        text = unicodedata.normalize(normalization, text)

    # Either squash all whitespace, or only horizontal whitespace.
    pattern = r"\s+" if flatten_whitespace else r"[ \t]+"
    return re.sub(pattern, " ", text).strip()
|
|
|
|
|
class LongformerClassificationHead(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
self.out_proj = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
def forward(self, hidden_states, **kwargs): |
|
|
x = hidden_states[:, 0, :] |
|
|
x = self.dropout(x) |
|
|
x = self.dense(x) |
|
|
x = torch.tanh(x) |
|
|
x = self.dropout(x) |
|
|
x = self.out_proj(x) |
|
|
return x |
|
|
|
|
|
class LongformerIntentModel(LongformerPreTrainedModel):
    """Longformer encoder with an intent-classification head.

    Wraps a ``LongformerModel`` backbone and a ``LongformerClassificationHead``
    and bundles its own tokenizer so ``predict`` can go straight from raw
    strings to per-intent percentages.
    """

    config_class = LongformerIntentConfig

    def __init__(self, config):
        """Build backbone, head and tokenizer from ``config``.

        NOTE(review): loading the tokenizer here performs disk/network I/O at
        construction time; presumably ``config._name_or_path`` points at the
        checkpoint this model was loaded from — confirm this holds when the
        model is instantiated from a bare config.
        """
        super().__init__(config)
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)

        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)

        # Initialize weights and apply any final HF processing.
        self.post_init()

    def forward(self, input_ids, attention_mask=None, global_attention_mask=None, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Standard padding mask, same shape as ``input_ids``.
            global_attention_mask: Longformer global-attention mask (1 where a
                token attends globally, typically the CLS position).
            labels: Optional class-index labels, shape (batch,). When given,
                a cross-entropy loss is computed.

        Returns:
            Dict with ``"logits"`` (batch, num_labels) and, only when
            ``labels`` is provided, ``"loss"``.
        """
        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        logits = self.classifier(outputs[0])

        result = {"logits": logits}
        if labels is not None:
            # Fix: ``labels`` was previously accepted but silently ignored,
            # so training loops received no loss. Standard single-label CE.
            loss_fct = nn.CrossEntropyLoss()
            result["loss"] = loss_fct(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )
        return result

    def predict(self, texts: Union[str, List[str]], batch_size: int = 8, device: str = None):
        """Predict intent percentages for one or more raw text strings.

        Args:
            texts: A single string or a list of strings.
            batch_size: Number of texts tokenized and scored per forward pass.
            device: Torch device string; defaults to CUDA when available.

        Returns:
            A list of dicts (one per input text) mapping each name in
            ``config.intent_columns`` to an integer percentage; percentages
            are rounded and nudged so each dict sums to exactly 100.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.to(device)
        self.eval()

        if isinstance(texts, str):
            texts = [texts]

        all_results = []
        for i in range(0, len(texts), batch_size):
            batch_texts = [clean_text(t) for t in texts[i : i + batch_size]]

            enc = self.tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                # NOTE(review): RoBERTa-style checkpoints reserve 2 positions
                # for the padding offset, so max_position_embeddings can exceed
                # the usable sequence length by 2 — consider
                # self.tokenizer.model_max_length instead; confirm for this
                # checkpoint before changing.
                max_length=self.config.max_position_embeddings,
                return_tensors="pt",
            ).to(device)

            # Global attention on the CLS token only, per Longformer
            # sequence-classification convention.
            global_mask = torch.zeros_like(enc["input_ids"])
            global_mask[:, 0] = 1

            with torch.no_grad():
                outputs = self.forward(
                    input_ids=enc["input_ids"],
                    attention_mask=enc["attention_mask"],
                    global_attention_mask=global_mask,
                )

            probs = torch.softmax(outputs["logits"], dim=-1).cpu().numpy()

            for row in probs:
                pct = (row * 100).round().astype(int)
                # Rounding can leave the row summing to 99 or 101; absorb the
                # discrepancy into the largest class so totals are exactly 100.
                diff = 100 - pct.sum()
                if diff != 0:
                    pct[pct.argmax()] += diff
                all_results.append(dict(zip(self.config.intent_columns, pct.tolist())))

        return all_results