Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import SessionState | |
| from mtranslate import translate | |
| from prompts import PROMPT_LIST | |
| import random | |
| import time | |
| from transformers import pipeline, set_seed, AutoConfig, AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer | |
| import psutil | |
| import torch | |
| import os | |
| from abstract_dataset import AbstractDataset | |
# st.set_page_config(page_title="Indonesian GPT-2")

# Public mirror advertised in the sidebar; can be overridden via the MIRROR_URL
# environment variable.
mirror_url = os.environ.get("MIRROR_URL", "https://abstract-generator.ai-research.id/")

# Registry of selectable models, keyed by the label shown in the sidebar selectbox.
# "text_generator" and "tokenizer" start out as None placeholders.
MODELS = {
    "Indonesian Academic Journal - Indonesian GPT-2 Medium": dict(
        group="Indonesian Journal",
        name="cahya/abstract-generator",
        description="Abstract Generator using Indonesian GPT-2 Medium.",
        text_generator=None,
        tokenizer=None,
    ),
}
# Sidebar content, rendered top to bottom: logo (plus the shared centring CSS),
# credits/links including the mirror URL, and a horizontal divider. All sections
# contain raw HTML, hence unsafe_allow_html=True.
_sidebar = st.sidebar
_SIDEBAR_SECTIONS = (
    """
<style>
.centeralign {
text-align: center;
}
</style>
<p class="centeralign">
<img src="https://huggingface.co/spaces/flax-community/gpt2-indonesian/resolve/main/huggingwayang.png"/>
</p>
""",
    f"""
___
<p class="centeralign">
This is a collection of applications that generates sentences using Indonesian GPT-2 models!
</p>
<p class="centeralign">
Created by <a href="https://huggingface.co/indonesian-nlp">Indonesian NLP</a> team @2021
<br/>
<a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">GitHub</a> | <a href="https://github.com/indonesian-nlp/gpt2-app" target="_blank">Project Report</a>
<br/>
A mirror of the application is available <a href="{mirror_url}" target="_blank">here</a>
</p>
""",
    """
___
""",
)
for _section in _SIDEBAR_SECTIONS:
    _sidebar.markdown(_section, unsafe_allow_html=True)

# Model picker; the selected label is used as the key into MODELS.
model_type = _sidebar.selectbox('Model', MODELS.keys())
def get_generator(model_name: str):
    """Load the GPT-2 LM-head model and tokenizer stored under ``model_name``.

    The tokenizer is extended with ``AbstractDataset.special_tokens``. The model
    config is created with the tokenizer's resulting bos/eos/sep/pad ids, and the
    model's token embeddings are resized to the tokenizer's new length.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    st.write(f"Loading the GPT2 model {model_name}, please wait...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens(AbstractDataset.special_tokens)
    # Mirror the tokenizer's special-token ids into the model config.
    special_ids = {
        f"{role}_token_id": getattr(tokenizer, f"{role}_token_id")
        for role in ("bos", "eos", "sep", "pad")
    }
    config = AutoConfig.from_pretrained(model_name, output_hidden_states=False, **special_ids)
    model = GPT2LMHeadModel.from_pretrained(model_name, config=config)
    # Grow the embedding matrix to cover tokens added by add_special_tokens.
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer
# NOTE: st.cache is intentionally not applied to this function; it caused issues
# on newer Streamlit versions. The previous decorator was
# @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id}).
def process(text_generator, tokenizer, title: str, keywords: str, text: str,
            max_length: int = 200, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0):
    """Generate an abstract from a title, comma-separated keywords and a text prompt.

    The model is prompted with ``<bos> title <sep> keywords <sep> text``. The
    returned string is the decoded output with that title/keywords prefix removed,
    i.e. ``text`` followed by the generated continuation.

    Args:
        text_generator: GPT-2 LM-head model exposing ``eval()`` and ``generate()``.
        tokenizer: tokenizer that knows ``AbstractDataset.special_tokens``.
        title: abstract title.
        keywords: comma-separated keywords; each is stripped of whitespace.
        text: opening text the generation continues from.
        max_length: maximum total sequence length in tokens (prompt included).
        do_sample, top_k, top_p, temperature: decoding options forwarded to ``generate``.
        max_time: wall-clock budget in seconds forwarded to ``generate``.
        seed: value passed to ``set_seed`` before generating.
        repetition_penalty: forwarded to ``generate``; ``0.0`` selects an automatic
            value derived from ``temperature``.

    Returns:
        The decoded text, special tokens skipped, without the title/keywords prefix.
    """
    set_seed(seed)
    if repetition_penalty == 0.0:
        # Automatic mode: interpolate between min and max penalty, increasing as
        # temperature drops below 1.0, but never going under 0.8.
        min_penalty = 1.05
        max_penalty = 1.5
        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)

    keyword_list = [keyword.strip() for keyword in keywords.split(",")]
    keywords = AbstractDataset.join_keywords(keyword_list, randomize=False)
    special_tokens = AbstractDataset.special_tokens
    # The prefix conditions the model but is removed from the returned text.
    prefix = (special_tokens['bos_token'] + title
              + special_tokens['sep_token'] + keywords + special_tokens['sep_token'])
    prompt = prefix + text
    print(f"title: {title}, keywords: {keywords}, text: {text}")

    generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
    # NOTE: runs on CPU; for GPU, move `generated` (and the model) to torch.device("cuda").
    text_generator.eval()
    sample_outputs = text_generator.generate(generated,
                                             do_sample=do_sample,
                                             # min_length may not exceed max_length; clamping
                                             # keeps EOS suppressed exactly as before.
                                             min_length=min(200, max_length),
                                             max_length=max_length,
                                             top_k=top_k,
                                             top_p=top_p,
                                             temperature=temperature,
                                             repetition_penalty=repetition_penalty,
                                             # Previously accepted but never applied.
                                             max_time=max_time,
                                             num_return_sequences=1
                                             )
    output_ids = sample_outputs[0]
    print(f"result: {tokenizer.decode(output_ids, skip_special_tokens=True)}")
    # Strip the prefix by token count rather than by len(title) + len(keywords).
    # Decoding may clean up tokenization spaces (e.g. " ," -> ","), so the decoded
    # prefix can be shorter than the raw strings and character slicing would cut
    # into the body. The prefix ends with a special token, so its token ids are an
    # exact leading slice of the prompt's ids.
    prefix_token_count = len(tokenizer.encode(prefix))
    return tokenizer.decode(output_ids[prefix_token_count:], skip_special_tokens=True)
# Page heading: app title, the selected model's group, a description, and a
# link to the model's Hugging Face page.
st.title("Indonesian GPT-2 Applications")
_model_entry = MODELS[model_type]
prompt_group_name = _model_entry["group"]
st.header(prompt_group_name)
description = (
    "This is a bilingual (Indonesian and English) abstract generator using Indonesian GPT-2 Medium. "
    "We finetuned it with the Indonesian paper abstract dataset."
)
st.markdown(description)
_repo_id = _model_entry["name"]
model_name = f"Model name: [{_repo_id}](https://huggingface.co/{_repo_id})"
st.markdown(model_name)
if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
    # Per-session state. title/keywords are initialised here as well: they are
    # read below on every run, and previously only the "Custom" branch set them.
    # A first run with any other prompt would then raise AttributeError.
    session_state = SessionState.get(prompt=None, prompt_box=None, text=None, title="", keywords="")
    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
    # Preselect the last entry, "Custom".
    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

    # Update prompt: switching to a different prompt discards the example text
    # picked for the previous one.
    if session_state.prompt is not None and prompt != session_state.prompt:
        session_state.prompt_box = None
    session_state.prompt = prompt

    # Update prompt box: "Custom" starts from empty fields. Any other prompt gets a
    # random example from PROMPT_LIST, chosen once until the selection changes or
    # the state is reset.
    if session_state.prompt == "Custom":
        session_state.prompt_box = ""
        session_state.title = ""
        session_state.keywords = ""
    elif session_state.prompt_box is None:
        session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])

    session_state.title = st.text_input("Title", session_state.title)
    session_state.keywords = st.text_input("Keywords", session_state.keywords)
    session_state.text = st.text_area("Prompt", session_state.prompt_box)
# Sidebar decoding controls. Bounds keep the values within what transformers'
# generate() accepts: sampling rejects temperature <= 0, top_p outside [0, 1] and
# negative top_k (0 disables top-k).
max_length = st.sidebar.number_input(
    "Maximum length",
    value=200,
    min_value=1,
    max_value=512,
    help="The maximum length of the sequence to be generated."
)
temperature = st.sidebar.slider(
    "Temperature",
    value=0.4,
    min_value=0.01,
    max_value=2.0
)
do_sample = st.sidebar.checkbox(
    "Use sampling",
    value=True
)
# top_k/top_p are only user-editable when sampling is enabled.
top_k = 30
top_p = 0.95
if do_sample:
    top_k = st.sidebar.number_input(
        "Top k",
        value=top_k,
        min_value=0,
        help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
    )
    top_p = st.sidebar.number_input(
        "Top p",
        value=top_p,
        min_value=0.0,
        max_value=1.0,
        help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
             "are kept for generation."
    )
seed = st.sidebar.number_input(
    "Random Seed",
    value=25,
    help="The number used to initialize a pseudorandom number generator"
)
# 0.0 is the sentinel that process() interprets as "choose automatically".
repetition_penalty = 0.0
automatic_repetition_penalty = st.sidebar.checkbox(
    "Automatic Repetition Penalty",
    value=True
)
if not automatic_repetition_penalty:
    repetition_penalty = st.sidebar.slider(
        "Repetition Penalty",
        value=1.0,
        min_value=1.0,
        max_value=2.0
    )
# Load only the model selected in the sidebar. The Run handler only uses
# MODELS[model_type], so loading every registered model would multiply load time
# and memory as soon as more than one model is configured.
selected_model = MODELS[model_type]
if selected_model["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
    selected_model["text_generator"], selected_model["tokenizer"] = \
        get_generator(selected_model["name"])
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        # Host memory snapshot taken before generation; reported in the footer.
        memory = psutil.virtual_memory()
        st.subheader("Result")
        # perf_counter is monotonic, so the duration is not skewed by clock changes.
        time_start = time.perf_counter()
        selected_model = MODELS[model_type]
        result = process(selected_model["text_generator"], selected_model["tokenizer"],
                         title=session_state.title,
                         keywords=session_state.keywords,
                         text=session_state.text, max_length=int(max_length),
                         temperature=temperature, do_sample=do_sample,
                         top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
        time_diff = time.perf_counter() - time_start
        st.write(result.replace("\n", " \n"))
        st.text("Translation")
        # translate() calls a remote service. A network/service failure should not
        # hide the generated result and statistics, so report it and continue.
        try:
            translation = translate(result, "en", "id")
        except Exception as err:
            print(f"translation failed: {err}")
            st.warning(f"Translation is currently unavailable: {err}")
        else:
            st.write(translation.replace("\n", " \n"))
        info = f"""
*Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*
*Text generated in {time_diff:.5} seconds*
"""
        st.write(info)
    # Reset state so that the next run picks a fresh prompt example.
    session_state.prompt = None
    session_state.prompt_box = None