Spaces:
Build error
Build error
| import gradio as gr | |
| import numpy as np | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from camel_tools.data import CATALOGUE | |
| from camel_tools.tagger.default import DefaultTagger | |
| from camel_tools.disambig.bert import BERTUnfactoredDisambiguator | |
def predict_label(text):
    """Run the staged NER pipeline on one input text.

    The text is split on whitespace and scored by two span models:
    ``span_model``, and ``msa_span_model``, which also receives POS tags from
    ``tagger``. Both models' span scores are concatenated and pooled, and the
    pooled spans are then scored by ``entity_model``.

    This function relies on module-level globals created in the ``__main__``
    block: ``span_model``, ``msa_span_model``, ``entity_model``, ``tagger``,
    and the ``extract_*`` / ``pool_*`` helpers imported from ``src.predict``.

    Args:
        text: Raw input text from the Gradio textbox (may be empty or None).

    Returns:
        The first value returned by ``pool_ent_scores`` (``combined_sequences``),
        or an empty list when ``text`` contains no tokens.
    """
    tokens = text.split() if text else []
    if not tokens:
        # Nothing to tag: skip the model calls rather than pass them
        # zero-length sequences.
        return []
    lengths = [len(tokens)]

    span_scores = extract_spannet_scores(span_model, tokens, lengths)
    pos_tags = tagger.tag(tokens)
    msa_span_scores = extract_spannet_scores(msa_span_model, tokens, lengths,
                                             pos=pos_tags)

    # Ensemble: concatenate both models' span scores, then pool them together.
    ensemble_span_scores = [*span_scores, *msa_span_scores]
    ensemble_pooled_scores = pool_span_scores(ensemble_span_scores, lengths)

    ent_scores = extract_ent_scores(entity_model, tokens, ensemble_pooled_scores)
    combined_sequences, _ent_pred_tags = pool_ent_scores(ent_scores, lengths)
    return combined_sequences
if __name__ == '__main__':
    # Token passed to hf_hub_download; read from the environment variable 'key'.
    space_key = os.environ.get('key')

    # Download the model source files into ./src so they can be imported below.
    source_files = ['network.py', 'layers.py', 'utils.py',
                    'representation.py', 'predict.py', 'validate.py']
    for filename in source_files:
        hf_hub_download('nehalelkaref/stagedNER',
                        filename=filename,
                        local_dir='src',
                        token=space_key)

    # Fetch CAMeL Tools data packages; force=True re-downloads on every start.
    CATALOGUE.download_package("all",
                               recursive=True,
                               force=True,
                               print_status=True)

    # These modules only exist after the download above, hence the late imports.
    # They are bound at module scope because predict_label reads them as globals.
    from src.predict import extract_spannet_scores, extract_ent_scores, pool_span_scores, pool_ent_scores
    from src.network import SpanNet, EntNet
    from src.validate import entities_from_token_classes

    disambig = BERTUnfactoredDisambiguator.pretrained('msa')
    tagger = DefaultTagger(disambig, 'pos')

    # Model checkpoint locations, overridable via environment variables.
    # NOTE(review): confirm these default files are present in the Space repo.
    span_path = os.environ.get('SPAN_MODEL_PATH', 'models/span.model')
    msa_span_path = os.environ.get('MSA_SPAN_MODEL_PATH',
                                   'new_models/msa.best.model')
    entity_path = os.environ.get('ENTITY_MODEL_PATH',
                                 'models/entity.msa.model')

    span_model = SpanNet.load_model(span_path)
    msa_span_model = SpanNet.load_model(msa_span_path)
    entity_model = EntNet.load_model(entity_path)

    with gr.Blocks(theme='finlaymacklon/smooth_slate') as iface:
        example_input = gr.Textbox(label="Input Example", lines=3)
        prediction = gr.Text(label="Predicted Entities")
        # The theme is set on the enclosing Blocks; 'smooth_slate' (without the
        # owner prefix) is not a built-in Gradio theme, so it is not repeated here.
        gr.Interface(fn=predict_label, inputs=example_input,
                     outputs=prediction,
                     title="Flat Entity Classification for Levantine Arabic")
        gr.Examples(
            examples=["النشرة الإخبارية الصادرة عن الأونروا رقم 113 (1986/1/8).",
                      "صورة لمدينة أريحا القديمة :تل السلطان",
                      "صورة اطفال مخيم للاجئين الفلسطينيين في لبنان"],
            inputs=example_input)
    iface.launch(show_api=False)