""" Core models for SynCABEL """ import json import logging import os import pickle import re from typing import Optional import torch import torch.nn.functional as F from huggingface_hub import hf_hub_download from transformers import ( AutoTokenizer, LlamaForCausalLM, PretrainedConfig, ) from .guided_inference import get_prefix_allowed_tokens_fn logger = logging.getLogger(__name__) logging.basicConfig( level=logging.INFO, format="%(levelname)s - %(message)s", ) # Define a simple config class that inherits from PretrainedConfig class LLamaSynCABELConfig(PretrainedConfig): model_type = "llama_syncabel" def __init__(self, **kwargs): # Ensure it has llama as base kwargs.setdefault("model_type", "llama") super().__init__(**kwargs) def chunk_it(seq, num): assert num > 0 chunk_len = len(seq) // num chunks = [seq[i * chunk_len : i * chunk_len + chunk_len] for i in range(num)] diff = len(seq) - chunk_len * num for i in range(diff): chunks[i].append(seq[chunk_len * num + i]) return chunks def find_mention(text: str) -> str: match = re.search(r"\[(.*?)\]", text) if match: return match.group(1).strip() else: raise ValueError("No mention found in the text.") def find_sem_group(text: str) -> str: match = re.search(r"\{(.*?)\}", text) if match: return match.group(1).strip() else: raise ValueError("No group type found in the text.") def parse_prediction( outputs: list[str], sem_groups: list[str], verb: str, text_to_code: Optional[dict[str, dict[str, str]]] = None, multiple_answers: bool = False, ) -> tuple[list[str], list[str]]: codes = [] predictions = [] for output, group in zip(outputs, sem_groups): splits = output.split(f"] {verb}") # type: ignore if len(splits) > 1 and splits[1].strip(): prediction = splits[1].strip() if text_to_code: if multiple_answers: prediction_list = prediction.split("") # type: ignore code_list = [] for pred in prediction_list: code_list.append( text_to_code[group].get(pred.strip(), "NO_CODE") ) code = "+".join(code_list) else: code = text_to_code[group].get(prediction, "NO_CODE") else: code = "NO_CODE" else: print( "IndexError: splitting failed or empty prediction, adding empty string as prediction." ) print(f"Full text: {output}") # type: ignore prediction = "NO_PREDICTION" code = "NO_CODE" codes.append(code) predictions.append(prediction) return codes, predictions def compute_score(outputs, tokenizer, prefix_len=0): sequences = outputs.sequences scores = outputs.scores N, total_len = sequences.shape T = len(scores) sequences = sequences[:, prefix_len : prefix_len + T] if len(scores) > sequences.size(1): scores = scores[: sequences.size(1)] mask = ( (sequences != tokenizer.pad_token_id) & (sequences != tokenizer.eos_token_id) & (sequences != tokenizer.bos_token_id) ) logprob_steps = [] for t, logits in enumerate(scores): log_probs_t = F.log_softmax(logits, dim=-1) token_t = sequences[:, t] idx = torch.arange(N) logprob_steps.append(log_probs_t[idx, token_t]) logprobs = torch.stack(logprob_steps, dim=1) logprobs.masked_fill_(~mask, 0) lengths = mask.sum(dim=1).clamp(min=1) confidence = torch.exp(logprobs.sum(dim=1) / lengths) return confidence.tolist() def skip_undesired_tokens(outputs, tokenizer): sep_token = tokenizer.sep_token if tokenizer.sep_token is not None else None if any("tag" in token for token in tokenizer.all_special_tokens): tokens_to_remove = tokenizer.all_special_tokens[:-3] elif any("{" in token for token in tokenizer.all_special_tokens): tokens_to_remove = tokenizer.all_special_tokens[:-4] else: tokens_to_remove = tokenizer.all_special_tokens if sep_token in tokens_to_remove: tokens_to_remove = [tok for tok in tokens_to_remove if tok != sep_token] cleaned_outputs = [] for sequence in outputs: for token in tokens_to_remove: sequence = sequence.replace(token, "") if sep_token: sequence = re.sub(rf"({re.escape(sep_token)})\s+", r"\1", sequence) cleaned_outputs.append(sequence.strip()) return cleaned_outputs class LLamaSynCABEL(LlamaForCausalLM): config_class = LLamaSynCABELConfig def __init__(self, config, *args, **kwargs): # Initialize the parent LlamaForCausalLM super().__init__(config, *args, **kwargs) # Store language from config self.lang = getattr(config, "lang", "en") self.text_to_code = None self.candidate_trie = None self.tokenizer = None @classmethod def from_pretrained( cls, pretrained_model_name_or_path, *args, lang=None, text_to_code_path=None, candidate_trie_path=None, **kwargs, ): # Remove custom kwargs before passing to parent custom_kwargs = { "lang": lang, "text_to_code_path": text_to_code_path, "candidate_trie_path": candidate_trie_path, } # Call parent's from_pretrained model = super().from_pretrained( pretrained_model_name_or_path, *args, **{k: v for k, v in kwargs.items() if k not in custom_kwargs}, ) # Set up tokenizer model.tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path, use_fast=True ) model.tokenizer.padding_side = "left" # Set language: explicit override > config > default if lang is not None: model.lang = lang elif hasattr(model.config, "lang"): model.lang = model.config.lang else: model.lang = "en" logger.info(f"Model language set to: {model.lang}") # Load text_to_code text_to_code_file_local = ( text_to_code_path if text_to_code_path is not None else os.path.join(pretrained_model_name_or_path, "text_to_code.json") ) try: if os.path.exists(text_to_code_file_local): with open(text_to_code_file_local, encoding="utf-8") as f: model.text_to_code = json.load(f) logger.info( f"Loaded text_to_code.json from local path: {text_to_code_file_local}" ) else: text_to_code_path_hf = hf_hub_download( repo_id=pretrained_model_name_or_path, filename="text_to_code.json", ) with open(text_to_code_path_hf, encoding="utf-8") as f: model.text_to_code = json.load(f) logger.info( f"Loaded text_to_code.json from HF Hub: {text_to_code_path_hf}" ) except Exception: logger.warning("text_to_code.json not found (local or HF hub)") model.text_to_code = None # Load candidate_trie candidate_trie_file_local = ( candidate_trie_path if candidate_trie_path is not None else os.path.join(pretrained_model_name_or_path, "candidate_trie.pkl") ) try: if os.path.exists(candidate_trie_file_local): with open(candidate_trie_file_local, "rb") as f: model.candidate_trie = pickle.load(f) logger.info( f"Loaded candidate_trie.pkl from local path: {candidate_trie_file_local}" ) else: candidate_trie_path_hf = hf_hub_download( repo_id=pretrained_model_name_or_path, filename="candidate_trie.pkl", ) with open(candidate_trie_path_hf, "rb") as f: model.candidate_trie = pickle.load(f) logger.info( f"Loaded candidate_trie.pkl from HF Hub: {candidate_trie_path_hf}" ) except Exception: logger.warning("candidate_trie.pkl not found (local or HF hub)") model.candidate_trie = None return model def sample( self, sentences: str | list[str], # type: ignore num_beams: int = 5, constrained: bool = True, multiple_answers: bool = False, **kwargs, ) -> list[list[dict[str, str]]]: if isinstance(sentences, str): sentences = [sentences] if self.lang == "fr": verb = "est" elif self.lang == "en": verb = "is" elif self.lang == "es": verb = "es" else: raise ValueError(f"Unsupported language: {self.lang}") prefix_templates = [] complete_input_text = [] sem_groups = [] mentions = [] for sent in sentences: sem_group = find_sem_group(sent) mention = find_mention(sent) prefix = f"[{mention}] {verb}" complete_input = f"{sent}{prefix}" mentions.append(mention) prefix_templates.append(prefix) complete_input_text.append(complete_input) sem_groups.append(sem_group) input_args = { k: v.to(self.device) for k, v in self.tokenizer.batch_encode_plus( # type: ignore complete_input_text, padding="longest", return_tensors="pt" ).items() } prefix_allowed_tokens_fn = None if constrained: if self.candidate_trie is None: raise ValueError( "candidate_trie is not loaded in the model. Use constrained=False." ) prefix_allowed_tokens_fn = get_prefix_allowed_tokens_fn( self, sentences, prefix_templates, sem_groups, multiple_answers=multiple_answers, ) outputs = self.generate( **input_args, max_new_tokens=128, num_beams=num_beams, num_return_sequences=num_beams, output_scores=True, return_dict_in_generate=True, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, **kwargs, ) decoded_sequences = self.tokenizer.batch_decode( # type: ignore outputs.sequences, # type: ignore skip_special_tokens=False, clean_up_tokenization_spaces=True, ) cleaned_output_sequences = skip_undesired_tokens( decoded_sequences, self.tokenizer, ) prefix_len = input_args["input_ids"].size(1) sem_groups = [x for x in sem_groups for _ in range(num_beams)] mentions = [x for x in mentions for _ in range(num_beams)] codes, predictions = parse_prediction( cleaned_output_sequences, sem_groups, verb, self.text_to_code, multiple_answers=multiple_answers, ) scores = compute_score(outputs, self.tokenizer, prefix_len=prefix_len) beam_scores = [ float(torch.exp(s)) if num_beams > 1 else float("nan") for s in ( outputs.sequences_scores # type: ignore if num_beams > 1 else [torch.tensor(float("nan"))] * len(scores) ) ] outputs = chunk_it( [ { "text": text, "mention": mention, "semantic_group": group, "pred_concept_name": prediction, "pred_concept_code": code, "score": score, "beam_score": beam_score, } for text, score, beam_score, code, prediction, mention, group in zip( cleaned_output_sequences, scores, beam_scores, codes, predictions, mentions, sem_groups, ) ], len(sentences), ) return outputs def encode(self, sentence): return self.tokenizer.encode(sentence, return_tensors="pt")[0] # type: ignore