Spaces:
Build error
Build error
| import streamlit as st | |
| from annotated_text import annotated_text | |
| import torch | |
| from transformers import pipeline | |
| from transformers import AutoModelForTokenClassification, AutoTokenizer | |
| import json | |
st.set_page_config(layout="wide")

# Local directory holding the fine-tuned token-classification checkpoint.
MODEL_DIR = "models/lusa"

# Streamlit re-runs this script on every user interaction, so loading the
# model at module level would reload it each time. Use the resource cache
# when the installed Streamlit provides one; otherwise fall back to a no-op
# decorator so the script still runs (just without caching).
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", lambda func: func
)


@_cache_resource
def _load_tagging_resources():
    """Load the model, its tokenizer and the NER pipeline built from them.

    Returns:
        A ``(model, tokenizer, tagger)`` tuple. ``tagger`` is a ``"ner"``
        pipeline using ``aggregation_strategy='first'``.
    """
    loaded_model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR)
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, model_max_length=512)
    ner_pipeline = pipeline(
        "ner",
        model=loaded_model,
        tokenizer=loaded_tokenizer,
        aggregation_strategy='first',  # alternative tried: aggregation_strategy='max'
    )
    return loaded_model, loaded_tokenizer, ner_pipeline


# Module-level names used by the functions below.
model, tokenizer, tagger = _load_tagging_resources()
def aggregate_subwords(input_tokens, labels):
    """Merge ``##``-prefixed word pieces back into whole words.

    Each word keeps the label of its first piece; labels of continuation
    pieces are discarded.

    Args:
        input_tokens: Tokens in order; continuation pieces start with ``##``.
        labels: One label per token, indexable by token position.

    Returns:
        A ``(words, word_labels)`` tuple of two equally long lists. Empty
        input yields two empty lists.

    Raises:
        IndexError: If ``labels`` has no entry for a word-initial token.
    """
    new_inputs = []
    new_labels = []
    for position, token in enumerate(input_tokens):
        is_continuation = token.startswith('##')
        if is_continuation and new_inputs:
            # Glue the piece (without its marker) onto the word in progress.
            new_inputs[-1] += token[2:]
            continue
        # A continuation piece with nothing before it starts a word of its
        # own, so it keeps its own label instead of an empty one.
        new_inputs.append(token[2:] if is_continuation else token)
        new_labels.append(labels[position])
    return new_inputs, new_labels
def annotateTriggers(line):
    """Tag *line* with the token-classification model and merge BIO spans.

    Args:
        line: Raw input text; surrounding whitespace is stripped.

    Returns:
        A list of ``(text, tag, entity)`` tuples in reading order. ``tag`` is
        ``'O'`` for plain words (``entity`` is ``''``), ``'B'`` for a
        single-word entity and ``'I'`` for an entity whose ``text`` holds
        several words joined by spaces.

    Raises:
        ValueError: If a predicted label is neither ``'O'`` nor prefixed with
            ``'B-'`` / ``'I-'``.
    """
    line = line.strip()
    # truncation=True keeps over-long input within the tokenizer's
    # model_max_length instead of letting the model fail on it.
    inputs = tokenizer(line, return_tensors="pt", truncation=True)
    input_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=2)
    predicted_token_class = [model.config.id2label[t.item()] for t in predictions[0]]
    input_tokens, predicted_token_class = aggregate_subwords(input_tokens, predicted_token_class)

    # Drop the first and last entries BEFORE merging spans. Slicing after the
    # merge would discard a real word whenever it had been merged into, or
    # had absorbed, one of those entries.
    # NOTE(review): assumes the tokenizer adds exactly one special token at
    # each end (e.g. [CLS] ... [SEP]) - confirm for this checkpoint.
    input_tokens = input_tokens[1:-1]
    predicted_token_class = predicted_token_class[1:-1]

    token_labels = []
    current_entity = ''
    for token, label in zip(input_tokens, predicted_token_class):
        if label == 'O':
            token_labels.append((token, 'O', ''))
            current_entity = ''
        elif label.startswith('B-'):
            current_entity = label[2:]
            token_labels.append((token, 'B', current_entity))
        elif label.startswith('I-'):
            if current_entity == '':
                # Raw model output is not guaranteed to be a valid BIO
                # sequence; treat an I- tag with no open entity as the start
                # of a new entity rather than aborting.
                current_entity = label[2:]
                token_labels.append((token, 'B', current_entity))
            else:
                token_labels[-1] = (token_labels[-1][0] + f" {token}", 'I', current_entity)
        else:
            raise ValueError(f"Invalid label: {label}")
    return token_labels
def joinEntities(entities):
    """Collapse each B-tagged entity and the I-tagged entities after it.

    Args:
        entities: Sequence of dicts with the keys ``'entity'``, ``'score'``,
            ``'index'``, ``'word'``, ``'start'`` and ``'end'``. The first
            character of ``'entity'`` decides whether an item opens a group
            (``'B'``) or continues one (``'I'``).

    Returns:
        A list with one dict per group: the label without its two-character
        prefix, the highest score, the lowest index, the words joined by
        spaces, and the start/end of the first/last member. Items that do not
        belong to a group opened by a ``'B'`` item are left out.
    """
    merged = []
    total = len(entities)
    pos = 0
    while pos < total:
        if entities[pos]['entity'][0] != 'B':
            pos += 1
            continue

        # Extend the group over every directly following I-tagged item.
        stop = pos + 1
        while stop < total and entities[stop]['entity'][0] == 'I':
            stop += 1

        group = entities[pos:stop]
        merged.append({
            'entity': entities[pos]['entity'][2:],
            'score': max(member['score'] for member in group),
            'index': min(member['index'] for member in group),
            'word': ' '.join(member['word'] for member in group),
            'start': group[0]['start'],
            'end': group[-1]['end'],
        })
        pos = stop
    return merged
| import pysbd | |
# Module-level sentence segmenter used by sent_tokenize(). It is configured
# with pysbd's "es" rules and clean=False.
# NOTE(review): the sample inputs in this app are Portuguese; presumably "es"
# is a stand-in for a language pysbd does not ship - confirm.
seg = pysbd.Segmenter(language="es", clean=False)


def sent_tokenize(text):
    """Return the sentences that the module-level segmenter finds in *text*."""
    sentences = seg.segment(text)
    return sentences
def getSentenceIndex(lines, span):
    """Return the index of the sentence that contains character offset *span*.

    Sentences are treated as laid out back to back, so sentence ``k`` covers
    the offsets from the summed length of sentences ``0..k-1`` (inclusive) up
    to that sum plus its own length (exclusive).

    Args:
        lines: Sentences in document order.
        span: Zero-based character offset into the concatenated sentences.

    Returns:
        The index of the containing sentence. An offset past the end maps to
        the last sentence; an empty ``lines`` yields ``0``.
    """
    consumed = 0
    for index, sentence in enumerate(lines):
        consumed += len(sentence)
        # Strict comparison: an offset equal to the running total is the
        # first character of the NEXT sentence, not part of this one.
        if span < consumed:
            return index
    # Offset beyond the text (or no sentences): clamp instead of raising.
    return max(len(lines) - 1, 0)
def generateContext(text, window, span):
    """Return the sentence around offset *span* plus its neighbours.

    Args:
        text: Text to split into sentences with ``sent_tokenize``.
        window: Number of sentences to include on each side of the sentence
            selected by ``getSentenceIndex``.
        span: Character offset handed to ``getSentenceIndex``.

    Returns:
        The selected sentences joined by single spaces.
    """
    sentences = sent_tokenize(text)
    center = getSentenceIndex(sentences, span)
    first = max(0, center - window)  # never slice from before the start
    last = center + window + 1
    return " ".join(sentences[first:last])
def annotateEvents(text, squad, window):
    """Run the NER pipeline on *text* and describe each detected trigger.

    Args:
        text: Raw input text; surrounding whitespace is stripped.
        squad: Unused; kept so existing callers keep working.
        window: Number of neighbouring sentences passed on to
            ``generateContext`` for each trigger.

    Returns:
        A list with one dict per pipeline result, holding ``"trigger"`` (the
        result's ``"word"``), ``"type"`` (its ``"entity_group"``),
        ``"score"`` and ``"context"`` (the text from ``generateContext``
        around the result's ``"start"`` offset).
    """
    text = text.strip()
    ner_results = tagger(text)

    # No BIO-prefix stripping is done here: "entity_group" is used exactly as
    # the pipeline returns it.
    events = []
    for trigger in ner_results:
        tipo = trigger["entity_group"]
        context = generateContext(text, window, trigger["start"])
        events.append({
            "trigger": trigger["word"],
            "type": tipo,
            "score": trigger["score"],
            "context": context,
        })
    return events
# Extra sample input that is not offered in the selector:
# "A Joana foi atacada pelo João nas ruas do Porto, com uma faca."
st.title('Extract Events')

# Example texts offered in the select box.
options = [
    "O presidente da Federação Haitiana de Futebol, Yves Jean-Bart, foi banido para sempre de toda a atividade ligada ao futebol, por ter sido considerado culpado de abuso sexual sistemático de jogadoras, anunciou hoje a FIFA.",
    "O barco de pesca Figaro ainda está a flutuar, embora esteja à deriva e ainda a arder.",
    "O navio 'Figaro', no qual viajavam 30 tripulantes - 16 angolanos, cinco espanhóis, cinco senegaleses, três peruanos e um do Gana - acionou por telefone o alarme de incêndio a bordo.",
    "A Polícia Judiciária (PJ) está a investigar o aparecimento de ossadas que foram hoje avistadas pelo proprietário de um terreno na freguesia de Meadela, em Viana do Castelo, disse à Lusa fonte daquela força policial.",
]
option = st.selectbox('Select examples', options)

# The chosen example pre-fills the text area; the user may edit it freely.
line = st.text_area("Insert Text", option)
st.button('Run')
st.sidebar.write("## Hyperparameters :gear:")
window = 1  # sentences of context on each side of a trigger

# Skip empty and whitespace-only input.
if line.strip():
    st.header("Triggers:")
    triggerss = annotateTriggers(line)
    # Plain words are rendered as text, entities as (text, label) annotations.
    annotated_text(*[
        word[0] + " " if word[1] == 'O' else (word[0] + " ", word[2])
        for word in triggerss
    ])

    # annotateEvents ignores its second argument, so one call is enough.
    eventos = annotateEvents(line, 1, window)
    for mention in eventos:
        st.text(f"| Trigger: {mention['trigger']:20} | Type: {mention['type']:10} | Score: {str(round(mention['score'],3)):5} |")
        # NOTE(review): separator placement (once per event) was reconstructed
        # from a source whose indentation was lost - confirm.
        st.markdown("""---""")