| |
| |
| |
|
|
| |
| from typing import List |
|
|
| |
| import numpy as np |
| import torch |
|
|
| |
| from kgs_binding.relation_mapper_builder import RelationsMapperBuilder |
| from kgs_binding.kg_qa_binding_utils import load_kg_handler |
| from data.relation_utils import clean_relations |
| from model_utils import create_layers_head_mask |
|
|
| from transformers import ( |
| BartForConditionalGeneration, |
| BartTokenizer, |
| BartConfig, |
| DisjunctiveConstraint, |
| ) |
|
|
| from utils import get_jump_chunks |
|
|
| |
| |
| |
|
|
| |
| |
| |
| from custom_tokenizer import BartCustomTokenizerFast |
| from custom_bart import BartCustomConfig, BartCustomForConditionalGeneration |
| from utils import get_device, KGType, Model_Type |
|
|
| from kgs_binding.kg_base_wrapper import KGBaseHandler |
| from kgs_binding.swow_handler import SwowHandler |
| from kgs_binding.conceptnet_handler import ConceptNetHandler |
|
|
class Inference:
    """Text generation with a plain BART model plus encoder-attention extraction.

    The tokenizer is always the pretrained ``facebook/bart-large`` one, while
    the model weights and config are loaded from ``model_path``.
    """

    def __init__(self, model_path: str, max_length=32):
        """Load tokenizer and model.

        Args:
            model_path: Location handed to ``from_pretrained`` for both the
                config and the model weights.
            max_length: Used as the padded/truncated input length and as the
                ``max_length`` passed to ``generate``.
        """
        self.device = get_device()
        self.tokenizer = self.prepare_tokenizer()
        self.model = self.prepare_model(model_path)
        self.max_length = max_length

    def prepare_tokenizer(self):
        """Return the pretrained ``facebook/bart-large`` tokenizer."""
        return BartTokenizer.from_pretrained('facebook/bart-large')

    def prepare_model(self, model_path):
        """Load the BART model from ``model_path``, move it to the device and set eval mode."""
        config = BartConfig.from_pretrained(model_path)
        model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(self.device)
        model.eval()
        return model

    def pre_process_context(self, context):
        """Lower-case ``context`` and tokenize it to ``max_length`` PyTorch tensors."""
        context = context.lower()
        return self.tokenizer(
            context,
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors="pt",
        )

    def generate_based_on_context(self, context):
        """Generate one answer for ``context``.

        Returns:
            A tuple ``(response, encoder_attentions, model_input)`` where
            ``response`` is the list of decoded strings,
            ``encoder_attentions`` is taken verbatim from the generation
            output and ``model_input`` is the tokenized context.
        """
        model_input = self.pre_process_context(context)
        # Sampling is enabled (do_sample=True) together with 4 beams.
        generated_answers_encoded = self.model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            min_length=1,
            max_length=self.max_length,
            do_sample=True,
            early_stopping=True,
            num_beams=4,
            temperature=1.0,
            top_k=None,
            top_p=None,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_attentions=True,
            output_scores=True,
        )
        response = self.tokenizer.batch_decode(
            generated_answers_encoded['sequences'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def prepare_context_for_visualization(self, context):
        """Generate for ``context`` and package tokens and attentions per example.

        Returns:
            ``(response, examples)``. Each entry of ``examples`` is a dict with
            ``'words'`` (the non-padding input tokens, with the ``Ġ`` marker
            stripped) and ``'attentions'`` (a NumPy array holding that
            example's slice of every stacked attention tensor).
        """
        examples = []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        # Stacking yields (n_layers, batch, n_heads, src, tgt).
        encoder_outputs = torch.stack(encoder_outputs)
        print(encoder_outputs.size())
        # Swap the first two axes so we can iterate per example. ``view`` would
        # only reinterpret memory and mix layers with examples when batch > 1.
        encoder_attentions = encoder_outputs.transpose(0, 1)
        pad_token_id = self.tokenizer.pad_token_id
        for i, example_attentions in enumerate(encoder_attentions):
            indices = model_input['input_ids'][i].detach().cpu()
            tokens = self.tokenizer.convert_ids_to_tokens(indices)
            is_real_token = (indices != pad_token_id).tolist()
            words = [
                token.replace('Ġ', '')
                for token, keep in zip(tokens, is_real_token)
                if keep
            ]
            example = {
                'words': words,
                'attentions': example_attentions.detach().cpu().numpy(),
            }
            examples.append(example)
            print(example['words'])
        return response, examples
|
|
class RelationsInference:
    """Generation with the relation-aware custom BART model and optional KG constraints.

    Commonsense relations found in the input are fed to the custom tokenizer
    and model. Concepts related to the input (according to the knowledge
    graph handler) can additionally be turned into ``DisjunctiveConstraint``
    objects for constrained beam search.
    """

    def __init__(self, model_path: str, kg_type: "KGType", model_type: "Model_Type", max_length=32):
        """Load the knowledge-graph handler, both tokenizers and the model.

        Args:
            model_path: Location handed to ``from_pretrained`` for the custom
                config and model.
            kg_type: Knowledge graph selector passed to ``load_kg_handler``.
            model_type: Queried for the tokenizer operation mode and for the
                default of ``config.is_simple_mask_commonsense``.
            max_length: Used as the padded/truncated input length and as the
                ``max_length`` passed to ``generate``.
        """
        self.device = get_device()
        kg_handler: KGBaseHandler = load_kg_handler(kg_type)
        self.kg_handler = kg_handler
        relation_names = kg_handler.get_relation_types()
        self.tokenizer = self.prepare_tokenizer(relation_names, model_type)
        # Plain BART tokenizer, used to tokenize constraint words.
        self.simple_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        self.model, self.config = self.prepare_model(relation_names, model_path, model_type)
        self.relation_mapper_builder = RelationsMapperBuilder(knowledge=kg_handler)
        self.max_length = max_length

    def prepare_tokenizer(self, relation_names: List[str], model_type: "Model_Type"):
        """Build the custom tokenizer and configure it with the known relation names."""
        tokenizer = BartCustomTokenizerFast.from_pretrained('facebook/bart-large')
        tokenizer.set_known_relation_names(relation_names)
        tokenizer.set_operation_mode(
            there_is_difference_between_relations=model_type.there_is_difference_between_relations()
        )
        return tokenizer

    def prepare_model(self, relation_names: List[str], model_path, model_type: "Model_Type"):
        """Load the custom config and model, filling in config fields that are unset.

        Returns:
            ``(model, config)`` with the model on ``self.device`` in eval mode.
        """
        config = BartCustomConfig.from_pretrained(model_path, revision='master')
        print('config.heads_mask:', config.heads_mask)
        # Only fill values the stored config leaves as None.
        if config.num_relation_kinds is None:
            config.num_relation_kinds = len(relation_names)
        if config.is_simple_mask_commonsense is None:
            config.is_simple_mask_commonsense = model_type.is_simple_mask_commonsense()
        if config.heads_mask is None:
            config.heads_mask = create_layers_head_mask(config)
        model = BartCustomForConditionalGeneration.from_pretrained(
            model_path, config=config, revision='master'
        ).to(self.device)
        model.eval()
        return model, config

    def pre_process_context(self, context):
        """Lower-case ``context``, look up its commonsense relations and tokenize it."""
        context = context.lower()
        commonsense_relations = self.relation_mapper_builder.get_relations_mapping_complex(
            context=[context], clear_common_wds=True
        )
        # A single context was submitted, so keep the first cleaned entry.
        commonsense_relation = clean_relations(commonsense_relations)[0]
        print(commonsense_relation)
        return self.tokenizer(
            context,
            padding='max_length',
            truncation='longest_first',
            max_length=self.max_length,
            return_tensors="pt",
            return_offsets_mapping=True,
            input_commonsense_relations=commonsense_relation,
        )

    def get_relations_information(self, phrase_generated):
        """Print and return concept/relation statistics for a generated phrase.

        Returns:
            ``(words, all_concepts, concepts_with_relations,
            concepts_with_no_relations)``.
        """
        all_concepts = self.relation_mapper_builder.get_kg_concepts_from_context(
            [phrase_generated], clear_common_wds=True
        )[0]
        words = phrase_generated.strip().split(' ')
        concepts_with_relations = self.relation_mapper_builder.get_concepts_from_context(
            phrase_generated, clear_common_wds=True
        )
        concepts_with_no_relations = list(set(all_concepts).difference(concepts_with_relations))

        print("====== RELATIONS SUMMARY ======")
        print('phrase_generated:', phrase_generated)
        print('words:', words)
        print('all_concepts:', all_concepts)
        print('concepts_with_relations:', concepts_with_relations)
        print('without_relations:', concepts_with_no_relations)
        print("\n== STATS:")
        print('n_words:', len(words))
        print('n_concepts:', len(all_concepts))
        print('n_concepts_with_relations:', len(concepts_with_relations))
        print('n_c_without_relations:', len(concepts_with_no_relations))
        print("====== ================= ======")
        return words, all_concepts, concepts_with_relations, concepts_with_no_relations

    def remove_subsets(self, l):
        """Return the sequences of ``l`` whose token set is not covered by another one.

        A sequence is dropped when its set of tokens is a proper subset of
        another sequence's set. Among sequences with equal token sets
        (duplicates or permutations) only the first occurrence is kept, so no
        two returned sequences are in a subset relation. ``l`` is not modified.
        """
        token_sets = [frozenset(sequence) for sequence in l]
        kept = []
        seen_sets = set()
        for sequence, tokens in zip(l, token_sets):
            if tokens in seen_sets:
                continue
            if any(tokens < other for other in token_sets):
                continue
            seen_sets.add(tokens)
            kept.append(sequence)
        return kept

    def _generate(self, model_input, constraints):
        """Run beam-search generation for one tokenized input with optional constraints."""
        gen_kwargs = {}
        if "input_commonsense_relations" in model_input:
            gen_kwargs["relation_inputs"] = model_input.get("input_commonsense_relations").to(self.device)
        return self.model.generate(
            input_ids=model_input["input_ids"].to(self.device),
            attention_mask=model_input["attention_mask"].to(self.device),
            constraints=constraints,
            min_length=1,
            max_length=self.max_length,
            do_sample=False,
            early_stopping=True,
            num_beams=8,
            temperature=1.0,
            top_k=None,
            top_p=None,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
            return_dict_in_generate=True,
            output_attentions=True,
            output_scores=True,
            **gen_kwargs,
        )

    def _constraints_from_context(self, context):
        """Build one disjunctive constraint per context concept, or None if none can be built."""
        concepts_from_context = self.relation_mapper_builder.get_concepts_from_context(
            context=context, clear_common_wds=True
        )
        constraints = []
        for context_concept in concepts_from_context:
            related = self.relation_mapper_builder.knowledge.get_related_concepts(context_concept)
            neighbour_concepts = [f'{phrase}' for phrase in related]
            print('neighbour:', neighbour_concepts[:20])
            # Substring test: drop neighbours that are contained in the concept text.
            flexible_words = [word for word in neighbour_concepts if word not in context_concept]
            if not flexible_words:
                continue
            flexible_words_ids: List[List[int]] = self.simple_tokenizer(
                flexible_words, add_prefix_space=True, add_special_tokens=False
            ).input_ids
            flexible_words_ids = [ids for ids in flexible_words_ids if ids]
            flexible_words_ids = self.remove_subsets(flexible_words_ids)[:10]
            print('flexible_words_ids:', flexible_words_ids[:3])
            # DisjunctiveConstraint rejects an empty list, so skip such concepts.
            if flexible_words_ids:
                constraints.append(DisjunctiveConstraint(flexible_words_ids))
        return constraints or None

    def generate_based_on_context(self, context, use_kg=False):
        """Generate one answer for ``context``, optionally constrained by KG neighbours.

        Args:
            context: Input text.
            use_kg: When true, concepts related to the context concepts are
                turned into disjunctive constraints for the beam search.

        Returns:
            ``(response, encoder_attentions, model_input)``.
        """
        model_input = self.pre_process_context(context)
        constraints = self._constraints_from_context(context) if use_kg else None
        generated_answers_encoded = self._generate(model_input, constraints)
        response = self.tokenizer.batch_decode(
            generated_answers_encoded['sequences'],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        encoder_attentions = generated_answers_encoded['encoder_attentions']
        return response, encoder_attentions, model_input

    def get_related_concepts_list(self, knowledge, list_concepts):
        """Concatenate ``knowledge.get_related_concepts(c)`` for every concept, in order."""
        return [
            neighbour
            for concept in list_concepts
            for neighbour in knowledge.get_related_concepts(concept)
        ]

    def _constraints_from_neighbours(self, context, max_concepts):
        """Build up to ``max_concepts`` first-token constraints for one context, or None."""
        if max_concepts < 1:
            return None
        concepts = self.relation_mapper_builder.get_concepts_from_context(
            context=context, clear_common_wds=True
        )
        neighbours = self.get_related_concepts_list(self.kg_handler, concepts)
        # A leading space is prepended before tokenizing; short concepts are skipped.
        candidates = [f' {concept}' for concept in neighbours if len(concept) > 3]
        chunk_size = max(len(candidates) // max_concepts, 1)
        chunks = list(get_jump_chunks(candidates, jump=chunk_size))
        constraints = []
        for sub_concepts in chunks[:max_concepts]:
            flexible_words_ids: List[List[int]] = self.tokenizer(
                sub_concepts, add_special_tokens=False
            ).input_ids
            # Keep only the first token of each word, de-duplicated in first-seen order.
            first_token_ids = dict.fromkeys(word_ids[0] for word_ids in flexible_words_ids if word_ids)
            if not first_token_ids:
                continue
            constraints.append(DisjunctiveConstraint([[token_id] for token_id in first_token_ids]))
        return constraints or None

    def generate_contrained_based_on_context(self, contexts, use_kg=True, max_concepts=1):
        """Generate one answer per context, each with its own KG-derived constraints.

        Args:
            contexts: Iterable of input texts (iterated more than once).
            use_kg: When false, every context is generated without constraints.
            max_concepts: Maximum number of disjunctive constraints per context.

        Returns:
            ``(text_results, encoder_attentions_list, model_inputs)`` with one
            entry per context. ``encoder_attentions_list`` holds the first
            element of each generation's ``encoder_attentions``.
        """
        model_inputs = [self.pre_process_context(context) for context in contexts]
        if use_kg:
            all_constraints = [
                self._constraints_from_neighbours(context, max_concepts) for context in contexts
            ]
        else:
            all_constraints = [None] * len(model_inputs)

        generated_answers_encoded = []
        encoder_attentions_list = []
        # Pair every input with the constraints built for that same context.
        for inputs, constraints in zip(model_inputs, all_constraints):
            gen = self._generate(inputs, constraints)
            generated_answers_encoded.append(gen['sequences'][0].detach().cpu())
            encoder_attentions_list.append(gen['encoder_attentions'][0].detach().cpu())

        text_results = self.tokenizer.batch_decode(
            generated_answers_encoded,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        return text_results, encoder_attentions_list, model_inputs

    def prepare_context_for_visualization(self, context):
        """Generate for ``context`` and package tokens, attentions and relations per example.

        Returns:
            ``(response, examples, relations)``. Each entry of ``examples`` is
            a dict with ``'words'`` (non-padding input tokens, ``Ġ`` stripped)
            and ``'attentions'`` (NumPy array). ``relations`` holds the
            matching ``input_commonsense_relations`` entry per example.
        """
        examples, relations = [], []
        response, encoder_outputs, model_input = self.generate_based_on_context(context)
        input_commonsense_relations = model_input.get("input_commonsense_relations")
        # Stacking yields (n_layers, batch, n_heads, src, tgt).
        encoder_outputs = torch.stack(encoder_outputs)
        print(encoder_outputs.size())
        # Swap the first two axes so we can iterate per example. ``view`` would
        # only reinterpret memory and mix layers with examples when batch > 1.
        encoder_attentions = encoder_outputs.transpose(0, 1)
        pad_token_id = self.tokenizer.pad_token_id
        for i, example_attentions in enumerate(encoder_attentions):
            indices = model_input['input_ids'][i].detach().cpu()
            tokens = self.tokenizer.convert_ids_to_tokens(indices)
            is_real_token = (indices != pad_token_id).tolist()
            words = [
                token.replace('Ġ', '')
                for token, keep in zip(tokens, is_real_token)
                if keep
            ]
            example = {
                'words': words,
                'attentions': example_attentions.detach().cpu().numpy(),
            }
            examples.append(example)
            relations.append(input_commonsense_relations[i])
            print(example['words'])
        return response, examples, relations
|
|