Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import shutil | |
| import requests | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Hugging Face Hub repo holding both the fine-tuned checkpoint and its tokenizer.
_CHECKPOINT_REPO = "bs-modeling-metadata/checkpoints_all_04_23"
# Upper bound on the number of tokens generated per request.
_MAX_NEW_TOKENS = 128

# Filled by _load_resources() on first use and reused by every later request,
# so the checkpoint is loaded once per process instead of once per click.
_resources = {}


def _load_resources():
    """Return ``(model, tokenizer, bad_words_ids)``, loading them on first call."""
    if not _resources:
        model = AutoModelForCausalLM.from_pretrained(
            _CHECKPOINT_REPO, subfolder="checkpoint-30000step"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            _CHECKPOINT_REPO, subfolder="tokenizer", add_prefix_space=True
        )
        # Token sequences generation is forbidden to emit. Presumably this keeps
        # the model from echoing the entity-chain markup used in the prompt.
        # add_special_tokens=False keeps BOS/EOS out of the banned sequences.
        bad_words_ids = tokenizer(
            ["<ENTITY_CHAIN>", " </ENTITY_CHAIN> "], add_special_tokens=False
        ).input_ids
        # Populate the cache only after every load succeeded, so a failed
        # load is retried on the next request instead of caching partial state.
        _resources.update(
            model=model, tokenizer=tokenizer, bad_words_ids=bad_words_ids
        )
    return _resources["model"], _resources["tokenizer"], _resources["bad_words_ids"]


def _format_entities(entity):
    """Render a comma-separated entity string as the entity prompt segment.

    Each entity is stripped and wrapped as `` |name|`` inside an
    ``<ENTITY_CHAIN> ... </ENTITY_CHAIN>`` chain prefixed by ``entity |||``.
    Blank entries (from ``"a,,b"`` or a trailing comma) are dropped. When no
    entity remains, the bare ``"||| "`` separator is returned instead.
    """
    names = [name.strip() for name in (entity or "").split(",")]
    names = [name for name in names if name]
    if not names:
        return "||| "
    chain = "".join(" |" + name + "|" for name in names)
    return "entity ||| <ENTITY_CHAIN>" + chain + " </ENTITY_CHAIN> "


def _build_prompt(html, entity, website_desc, datasource, year, month, title, prompt):
    """Assemble the metadata-prefixed prompt string fed to the model.

    Empty or ``None`` metadata fields are omitted entirely.
    """
    parts = ["html | " if html == "on" else ""]
    # NOTE(review): this field order presumably mirrors the layout the
    # checkpoint was trained on. Keep it unchanged unless that is confirmed.
    for label, value in (
        ("Year", year),
        ("Month", month),
        ("Website Description", website_desc),
        ("Title", title),
        ("Datasource", datasource),
    ):
        if value:
            parts.append(label + ": " + value + " | ")
    parts.append(_format_entities(entity))
    parts.append(prompt or "")
    return "".join(parts)


def generate(html, entity, website_desc, datasource, year, month, title, prompt):
    """Generate a continuation of *prompt* conditioned on page metadata.

    Args:
        html: ``"on"`` to add the ``html | `` tag. Any other value,
            including ``None``, leaves it out.
        entity: Comma-separated entities or keywords. May be empty.
        website_desc: Website description. Omitted when empty.
        datasource: Data source name. Omitted when empty.
        year: Year string. Omitted when empty.
        month: Month string. Omitted when empty.
        title: Website title. Omitted when empty.
        prompt: Free text to continue.

    Returns:
        The decoded text of the single generated sequence, with special
        tokens removed. For a causal LM this is the full prompt followed by
        up to ``_MAX_NEW_TOKENS`` new tokens.
    """
    model, tokenizer, bad_words_ids = _load_resources()
    final_prompt = _build_prompt(
        html, entity, website_desc, datasource, year, month, title, prompt
    )
    inputs = tokenizer(final_prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs, max_new_tokens=_MAX_NEW_TOKENS, bad_words_ids=bad_words_ids
    )
    # One prompt yields one sequence. Return a plain string so the "text"
    # output component shows the text rather than a Python list repr.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Input components, listed in the same order as generate()'s parameters.
# Default to "off" so the radio starts in a defined state. generate() treats
# any value other than "on" (including None) as off anyway.
html = gr.Radio(["on", "off"], value="off", label="html", info="turn html as on or off")
entity = gr.Textbox(placeholder="enter a list of comma separated entities or keywords", label="list of entities")
website_desc = gr.Textbox(placeholder="enter a website description", label="website description")
datasource = gr.Textbox(placeholder="enter a datasource", label="datasource")
year = gr.Textbox(placeholder="enter a year", label="year")
month = gr.Textbox(placeholder="enter a month", label="month")
title = gr.Textbox(placeholder="enter a website title", label="website title")
prompt = gr.Textbox(placeholder="enter a prompt", label="prompt")

demo = gr.Interface(
    fn=generate,
    inputs=[html, entity, website_desc, datasource, year, month, title, prompt],
    outputs="text",
)

# Start the server only when run as a script, so importing this module
# (e.g. for reload tooling or tests) has no side effects.
if __name__ == "__main__":
    demo.launch()