Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pickle | |
| import urllib | |
| from transformers import pipeline | |
| from transformers import AutoModelForMaskedLM, AutoTokenizer | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
# Location of a pickled example figure that pre-populates the plot widget.
plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"

# Time prefixes prepended to the user's text, formatted "<yy> <month>"
# (e.g. "21 1" for January 2021), in chronological order: every month of
# 2018, 2019, 2020 and 2021.
dates = [f"{year} {month}" for year in ("18", "19", "20", "21") for month in range(1, 13)]

# Month number (as a string) of each prefix; used as x-axis tick labels.
months = [date.rsplit(" ", 1)[-1] for date in dates]
# Masked-LM checkpoint; per the demo description it was trained on tweets
# prefixed with "<year> <month>" tokens.
model_name = "fvancesco/tmp_date"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Module.eval() returns the module itself, so the call can be chained.
model = AutoModelForMaskedLM.from_pretrained(model_name).eval()
# Fill-mask pipeline on the default device (no `device` argument);
# pass device=0 to run on the first GPU instead.
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)
def get_mf_dict(text, top_n=5):
    """Collect per-month scores for the most likely fillers of the mask.

    Every prefix in the module-level ``dates`` (e.g. ``"21 1"``) is prepended
    to ``text`` and the resulting sentences are run through the module-level
    fill-mask ``pipe``. The ``top_n`` best candidates of each month are pooled,
    and for every pooled token its score in each month is gathered.

    Args:
        text: Input text; presumably contains exactly one mask token, since
            each pipeline result is treated as a single candidate list —
            TODO confirm behaviour for multi-mask inputs.
        top_n: Number of top candidates taken from each month's prediction.

    Returns:
        dict mapping candidate token string -> ``np.ndarray`` of length
        ``len(dates)`` with that token's score for each date, in ``dates``
        order. Keys are ordered by first appearance among the monthly top
        candidates, so the result (and hence the plot legend) is
        deterministic.
    """
    texts = [f"{d} {text}" for d in dates]
    # Ask for every vocabulary entry so that a token that is top-ranked in one
    # month still has a score in every other month.
    all_preds = pipe(texts, top_k=model.config.vocab_size)

    # Per date: token_str -> score. Different token ids can decode to the same
    # string; predictions arrive sorted by descending score, so keep the first
    # (highest) score rather than letting a low-ranked duplicate overwrite it.
    token_prob = []
    for month_preds in all_preds:
        scores = {}
        for p in month_preds:
            scores.setdefault(p['token_str'], p['score'])
        token_prob.append(scores)

    # A dict used as an insertion-ordered set (a plain set has no stable order).
    most_freq_tokens = {}
    for month_preds in all_preds:
        for p in month_preds[:top_n]:
            most_freq_tokens.setdefault(p['token_str'], None)

    return {
        token: np.array([scores.get(token, 0.0) for scores in token_prob], dtype=float)
        for token in most_freq_tokens
    }
def plot_time(text):
    """Plot how the fill-mask probabilities for ``text`` change month by month.

    Calls ``get_mf_dict(text)`` and draws one line per candidate token across
    the module-level ``dates``. Months are on the bottom x-axis and years on a
    twin x-axis at the top.

    Args:
        text: Input text containing a ``<mask>`` token.

    Returns:
        The ``matplotlib.figure.Figure`` holding the plot.
    """
    # Build the Figure directly rather than through pyplot: pyplot registers
    # every figure it creates and keeps it alive until plt.close(), so calling
    # plt.figure() once per request in a long-running app leaks figures.
    from matplotlib.figure import Figure

    mf_dict = get_mf_dict(text)
    n = len(dates)
    x = list(range(n))

    fig = Figure(figsize=(16, 9))
    ax = fig.add_subplot(111)
    ax.set_xlabel('Month')
    ax.set_xticks(x)
    ax.set_xticklabels(months)
    # Pin the range explicitly so both x-axes cover exactly the same months.
    ax.set_xlim(0, n - 1)

    # Top axis: major ticks (drawn as grid lines) at the first month of each
    # year and at the last month; year labels on minor ticks centred on each
    # year. The year is the first field of a prefix ("21 1" -> "21"); the
    # label assumes two-digit years of the 2000s.
    years = [d.split(" ")[0] for d in dates]
    year_starts = [i for i in x if i == 0 or years[i] != years[i - 1]]
    year_ends = year_starts[1:] + [n]
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xticks(year_starts + [n - 1])
    ax2.set_xticklabels([])
    ax2.set_xticks([(s + e) / 2 for s, e in zip(year_starts, year_ends)], minor=True)
    ax2.set_xticklabels([f"20{years[s]}" for s in year_starts], minor=True)
    ax2.set_xlim(0, n - 1)
    ax2.grid()

    # One line per candidate token; legend placed outside, right of the axes.
    for token, probs in mf_dict.items():
        ax.plot(x, probs, label=token)
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    return fig
def add_mask(text):
    """Return ``text`` with a ``<mask>`` token appended.

    A separating space is inserted only when ``text`` is non-empty and does
    not already end with a space.
    """
    if text and not text.endswith(" "):
        return text + " <mask>"
    return text + "<mask>"
# `import urllib` alone does not guarantee the `urllib.request` submodule is
# loaded; import it explicitly for urlopen() below.
import urllib.request

with gr.Blocks() as demo:
    # Raw string: keeps the Markdown escapes "\<" / "\>" literally without
    # Python's invalid-escape-sequence warning.
    text_description = r"""
# TimeLMs Demo
This is a demo for **timeLMs**:
- [Github](https://github.com/cardiffnlp/timelms)
- [Paper](https://aclanthology.org/2022.acl-demo.25.pdf)

Input any text with a *\<mask\>* token as in the example and click
**Generate Plot** (the demo does not use GPUs, so it takes about 1 min).
In the graph, we show the probability of some token candidates for the mask
over different months.

In this demo we use a roberta-base model trained on tweets, where the first
two tokens are the year and the month ("21 1" for January 2021). It was
trained on tweets from January 2018 to December 2021.
"""
    description = gr.Markdown(text_description)
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")
    # Start with a pre-rendered example figure so the plot is not empty.
    # SECURITY NOTE(review): pickle.load on data fetched over the network can
    # execute arbitrary code; only point plot_url at a trusted source.
    with urllib.request.urlopen(plot_url) as f:
        plot_example = pickle.load(f)
    plot = gr.Plot(plot_example)
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)
demo.launch(debug=True)