"""Core models for SynCABEL."""

import json
import logging
import os
import pickle
import re
from typing import Optional

import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from transformers import (
    AutoTokenizer,
    LlamaForCausalLM,
    PretrainedConfig,
)

from .guided_inference import get_prefix_allowed_tokens_fn

logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s - %(message)s",
)


class LLamaSynCABELConfig(PretrainedConfig):
    """Configuration for SynCABEL models built on a Llama backbone."""

    model_type = "llama_syncabel"

    def __init__(self, **kwargs):
        # Default to the base "llama" model_type so the HF machinery treats
        # this as a standard Llama config unless told otherwise.
        kwargs.setdefault("model_type", "llama")
        super().__init__(**kwargs)


def chunk_it(seq, num):
    """Split ``seq`` into ``num`` chunks, spreading any remainder over the first chunks."""
    assert num > 0
    chunk_len = len(seq) // num
    chunks = [seq[i * chunk_len : i * chunk_len + chunk_len] for i in range(num)]

    # Append each leftover item (len(seq) % num of them) to one of the first chunks.
    diff = len(seq) - chunk_len * num
    for i in range(diff):
        chunks[i].append(seq[chunk_len * num + i])

    return chunks
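
# Illustrative behavior (leftover items land at the tail of the earliest chunks):
#
#     >>> chunk_it(list(range(5)), 2)
#     [[0, 1, 4], [2, 3]]
#     >>> chunk_it(list(range(6)), 3)
#     [[0, 1], [2, 3], [4, 5]]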


def find_mention(text: str) -> str:
    """Return the first ``[...]``-bracketed mention in ``text``."""
    match = re.search(r"\[(.*?)\]", text)
    if match:
        return match.group(1).strip()
    raise ValueError("No mention found in the text.")


def find_sem_group(text: str) -> str:
    """Return the first ``{...}``-braced semantic group in ``text``."""
    match = re.search(r"\{(.*?)\}", text)
    if match:
        return match.group(1).strip()
    raise ValueError("No group type found in the text.")


def parse_prediction(
    outputs: list[str],
    sem_groups: list[str],
    verb: str,
    text_to_code: Optional[dict[str, dict[str, str]]] = None,
    multiple_answers: bool = False,
) -> tuple[list[str], list[str]]:
    """Extract concept names (and codes) from generated sequences.

    Each output is expected to contain ``... [mention] <verb> <concept name>``;
    the text after ``] <verb>`` is taken as the prediction. When
    ``text_to_code`` is given, the prediction is mapped to a concept code
    within its semantic group, falling back to ``NO_CODE``.
    """
    codes = []
    predictions = []
    for output, group in zip(outputs, sem_groups):
        splits = output.split(f"] {verb}")
        if len(splits) > 1 and splits[1].strip():
            prediction = splits[1].strip()
            if text_to_code:
                if multiple_answers:
                    # Multiple concepts are separated by <SEP>; join their codes with "+".
                    prediction_list = prediction.split("<SEP>")
                    code_list = []
                    for pred in prediction_list:
                        code_list.append(
                            text_to_code[group].get(pred.strip(), "NO_CODE")
                        )
                    code = "+".join(code_list)
                else:
                    code = text_to_code[group].get(prediction, "NO_CODE")
            else:
                code = "NO_CODE"
        else:
            logger.warning(
                "Splitting on '] %s' failed or yielded an empty prediction; "
                "falling back to NO_PREDICTION. Full text: %s",
                verb,
                output,
            )
            prediction = "NO_PREDICTION"
            code = "NO_CODE"
        codes.append(code)
        predictions.append(prediction)
    return codes, predictions
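
# Illustrative call (hypothetical generated text and code mapping; the real
# mapping is loaded from text_to_code.json):
#
#     >>> mapping = {"CHEM": {"acetylsalicylic acid": "D001241"}}
#     >>> parse_prediction(["[aspirin] is acetylsalicylic acid"], ["CHEM"], "is", mapping)
#     (['D001241'], ['acetylsalicylic acid'])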


def compute_score(outputs, tokenizer, prefix_len=0):
    """Length-normalized confidence per sequence: exp(mean per-token log-prob)."""
    sequences = outputs.sequences
    scores = outputs.scores

    N = sequences.size(0)
    T = len(scores)

    # Keep only the generated part of each sequence (drop the prompt prefix).
    sequences = sequences[:, prefix_len : prefix_len + T]

    # Trim the score steps if the kept sequences are shorter than T.
    if len(scores) > sequences.size(1):
        scores = scores[: sequences.size(1)]

    # Ignore pad/eos/bos tokens when averaging.
    mask = (
        (sequences != tokenizer.pad_token_id)
        & (sequences != tokenizer.eos_token_id)
        & (sequences != tokenizer.bos_token_id)
    )

    logprob_steps = []
    for t, logits in enumerate(scores):
        log_probs_t = F.log_softmax(logits, dim=-1)
        token_t = sequences[:, t]
        # Build the row index on the same device as the sequences to avoid
        # CPU/GPU mismatches when the model runs on an accelerator.
        idx = torch.arange(N, device=sequences.device)
        logprob_steps.append(log_probs_t[idx, token_t])

    logprobs = torch.stack(logprob_steps, dim=1)
    logprobs.masked_fill_(~mask, 0)

    lengths = mask.sum(dim=1).clamp(min=1)
    confidence = torch.exp(logprobs.sum(dim=1) / lengths)

    return confidence.tolist()
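
# The confidence for sequence i is exp(mean_t log p(y_t)) over its non-special
# generated tokens, i.e. the geometric mean of per-token probabilities, so
# longer answers are not penalized merely for their length.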


def skip_undesired_tokens(outputs, tokenizer):
    """Strip undesired special tokens from decoded sequences, keeping the separator."""
    sep_token = tokenizer.sep_token  # may be None

    # The slicing below assumes a fixed ordering of all_special_tokens in which
    # the last 3 (tag-style vocabularies) or last 4 (brace-style vocabularies)
    # specials are task markers that must survive cleaning.
    if any("tag" in token for token in tokenizer.all_special_tokens):
        tokens_to_remove = tokenizer.all_special_tokens[:-3]
    elif any("{" in token for token in tokenizer.all_special_tokens):
        tokens_to_remove = tokenizer.all_special_tokens[:-4]
    else:
        tokens_to_remove = tokenizer.all_special_tokens

    if sep_token in tokens_to_remove:
        tokens_to_remove = [tok for tok in tokens_to_remove if tok != sep_token]

    cleaned_outputs = []
    for sequence in outputs:
        for token in tokens_to_remove:
            sequence = sequence.replace(token, "")

        # Collapse whitespace right after the separator so later splitting is stable.
        if sep_token:
            sequence = re.sub(rf"({re.escape(sep_token)})\s+", r"\1", sequence)

        cleaned_outputs.append(sequence.strip())

    return cleaned_outputs
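
# Illustrative effect (hypothetical decoded string and special tokens; the real
# inventory depends on the tokenizer shipped with the checkpoint):
#
#     ["<s>text<SEP>[aspirin] is acetylsalicylic acid</s>"]
#         -> ["text<SEP>[aspirin] is acetylsalicylic acid"]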


class LLamaSynCABEL(LlamaForCausalLM):
    """Llama-based generative model for constrained entity linking (SynCABEL)."""

    config_class = LLamaSynCABELConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

        self.lang = getattr(config, "lang", "en")
        # Populated by from_pretrained (or manually) before calling sample().
        self.text_to_code = None
        self.candidate_trie = None
        self.tokenizer = None

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *args,
        lang=None,
        text_to_code_path=None,
        candidate_trie_path=None,
        **kwargs,
    ):
        """Load the model with its tokenizer, text-to-code map and candidate trie."""
        custom_kwargs = {
            "lang": lang,
            "text_to_code_path": text_to_code_path,
            "candidate_trie_path": candidate_trie_path,
        }

        # Forward only the standard HF kwargs to the parent loader; the custom
        # ones are consumed here.
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *args,
            **{k: v for k, v in kwargs.items() if k not in custom_kwargs},
        )

        model.tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, use_fast=True
        )
        # Left padding so batched generation starts right after each prompt.
        model.tokenizer.padding_side = "left"

        # Resolve the language: explicit argument > value stored in the config > "en".
        if lang is not None:
            model.lang = lang
        elif hasattr(model.config, "lang"):
            model.lang = model.config.lang
        else:
            model.lang = "en"

        logger.info(f"Model language set to: {model.lang}")

        # Load text_to_code.json, preferring a local file over the HF Hub.
        text_to_code_file_local = (
            text_to_code_path
            if text_to_code_path is not None
            else os.path.join(pretrained_model_name_or_path, "text_to_code.json")
        )
        try:
            if os.path.exists(text_to_code_file_local):
                with open(text_to_code_file_local, encoding="utf-8") as f:
                    model.text_to_code = json.load(f)
                logger.info(
                    f"Loaded text_to_code.json from local path: {text_to_code_file_local}"
                )
            else:
                text_to_code_path_hf = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="text_to_code.json",
                )
                with open(text_to_code_path_hf, encoding="utf-8") as f:
                    model.text_to_code = json.load(f)
                logger.info(
                    f"Loaded text_to_code.json from HF Hub: {text_to_code_path_hf}"
                )
        except Exception:
            logger.warning("text_to_code.json not found (local or HF Hub)")
            model.text_to_code = None

        # Load candidate_trie.pkl the same way (local first, then the HF Hub).
        candidate_trie_file_local = (
            candidate_trie_path
            if candidate_trie_path is not None
            else os.path.join(pretrained_model_name_or_path, "candidate_trie.pkl")
        )
        try:
            if os.path.exists(candidate_trie_file_local):
                with open(candidate_trie_file_local, "rb") as f:
                    model.candidate_trie = pickle.load(f)
                logger.info(
                    f"Loaded candidate_trie.pkl from local path: {candidate_trie_file_local}"
                )
            else:
                candidate_trie_path_hf = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="candidate_trie.pkl",
                )
                with open(candidate_trie_path_hf, "rb") as f:
                    model.candidate_trie = pickle.load(f)
                logger.info(
                    f"Loaded candidate_trie.pkl from HF Hub: {candidate_trie_path_hf}"
                )
        except Exception:
            logger.warning("candidate_trie.pkl not found (local or HF Hub)")
            model.candidate_trie = None

        return model
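
    # Illustrative loading call (the checkpoint path is a hypothetical placeholder):
    #
    #     model = LLamaSynCABEL.from_pretrained("path/to/syncabel-checkpoint", lang="en")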

    def sample(
        self,
        sentences: str | list[str],
        num_beams: int = 5,
        constrained: bool = True,
        multiple_answers: bool = False,
        **kwargs,
    ) -> list[list[dict[str, str]]]:
        """Generate entity-linking predictions; returns num_beams result dicts per sentence."""
        if isinstance(sentences, str):
            sentences = [sentences]

        # The copula in the "[mention] <verb> <concept>" template depends on the language.
        if self.lang == "fr":
            verb = "est"
        elif self.lang == "en":
            verb = "is"
        elif self.lang == "es":
            verb = "es"
        else:
            raise ValueError(f"Unsupported language: {self.lang}")

        # Build one "<sentence><SEP>[mention] <verb>" prompt per input sentence.
        prefix_templates = []
        complete_input_text = []
        sem_groups = []
        mentions = []
        for sent in sentences:
            sem_group = find_sem_group(sent)
            mention = find_mention(sent)
            prefix = f"[{mention}] {verb}"
            complete_input = f"{sent}<SEP>{prefix}"
            mentions.append(mention)
            prefix_templates.append(prefix)
            complete_input_text.append(complete_input)
            sem_groups.append(sem_group)

        # Encode the batch and move it to the model's device.
        input_args = {
            k: v.to(self.device)
            for k, v in self.tokenizer.batch_encode_plus(
                complete_input_text, padding="longest", return_tensors="pt"
            ).items()
        }

        # Under constrained decoding, beams are restricted to candidate concept
        # names stored in the trie for each mention's semantic group.
        prefix_allowed_tokens_fn = None
        if constrained:
            if self.candidate_trie is None:
                raise ValueError(
                    "candidate_trie is not loaded in the model. Use constrained=False."
                )
            prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn(
                self,
                sentences,
                prefix_templates,
                sem_groups,
                multiple_answers=multiple_answers,
            )

        outputs = self.generate(
            **input_args,
            max_new_tokens=128,
            num_beams=num_beams,
            num_return_sequences=num_beams,
            output_scores=True,
            return_dict_in_generate=True,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            **kwargs,
        )
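
        # With return_dict_in_generate=True, outputs.sequences holds the
        # (len(sentences) * num_beams) token-id rows (prompt included for this
        # decoder-only model) and outputs.scores one tensor per generated step.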

        decoded_sequences = self.tokenizer.batch_decode(
            outputs.sequences,
            skip_special_tokens=False,
            clean_up_tokenization_spaces=True,
        )
        cleaned_output_sequences = skip_undesired_tokens(
            decoded_sequences,
            self.tokenizer,
        )

        # With left padding, every row shares the same prompt length.
        prefix_len = input_args["input_ids"].size(1)

        # Repeat per-sentence metadata so it lines up with the num_beams
        # sequences returned for each input.
        sem_groups = [x for x in sem_groups for _ in range(num_beams)]
        mentions = [x for x in mentions for _ in range(num_beams)]

        codes, predictions = parse_prediction(
            cleaned_output_sequences,
            sem_groups,
            verb,
            self.text_to_code,
            multiple_answers=multiple_answers,
        )
        scores = compute_score(outputs, self.tokenizer, prefix_len=prefix_len)
        # sequences_scores exists only for beam search; emit NaN otherwise.
        beam_scores = [
            float(torch.exp(s)) if num_beams > 1 else float("nan")
            for s in (
                outputs.sequences_scores
                if num_beams > 1
                else [torch.tensor(float("nan"))] * len(scores)
            )
        ]

        # Regroup the flat list (num_beams entries per sentence) into one list
        # of result dicts per input sentence.
        results = chunk_it(
            [
                {
                    "text": text,
                    "mention": mention,
                    "semantic_group": group,
                    "pred_concept_name": prediction,
                    "pred_concept_code": code,
                    "score": score,
                    "beam_score": beam_score,
                }
                for text, score, beam_score, code, prediction, mention, group in zip(
                    cleaned_output_sequences,
                    scores,
                    beam_scores,
                    codes,
                    predictions,
                    mentions,
                    sem_groups,
                )
            ],
            len(sentences),
        )

        return results
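
    # Illustrative usage (hypothetical sentence; brackets mark the mention,
    # braces its semantic group):
    #
    #     preds = model.sample("The patient took [aspirin] {CHEM} daily.", num_beams=5)
    #     best = preds[0][0]  # top beam for the first sentence
    #     print(best["pred_concept_name"], best["pred_concept_code"], best["score"])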

    def encode(self, sentence):
        """Tokenize ``sentence`` and return its 1-D tensor of token ids."""
        return self.tokenizer.encode(sentence, return_tensors="pt")[0]