| import json |
| import os |
| import time |
| from random import randint |
|
|
| import psutil |
| import streamlit as st |
| import torch |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoModelForSeq2SeqLM, |
| AutoTokenizer, |
| pipeline, |
| set_seed, |
| ) |
|
|
| from generator import GeneratorFactory |
|
|
# -1 when no CUDA device is available; otherwise the index of the last CUDA
# device. main() compares this against -1 to choose its time-estimate formula.
device = torch.cuda.device_count() - 1


# Task string in the ``translation_<src>_to_<tgt>`` form: English -> Dutch,
# matching the "en->nl" descriptions of every model below.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"
# Backward-compatible alias kept for existing references. The name suggests
# NL -> EN, but the value has always been the English -> Dutch task; prefer
# TRANSLATION_EN_TO_NL in new code.
TRANSLATION_NL_TO_EN = TRANSLATION_EN_TO_NL


# Model configurations handed to GeneratorFactory in main(). Each entry holds
# the model identifier (presumably a Hugging Face Hub id), a human-readable
# description, and the task string above.
GENERATOR_LIST = [
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
        "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
        "desc": "longT5 large nl8 512beta/512l en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
    {
        "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
        "desc": "T5 small nl24 ccmatrix en->nl",
        "task": TRANSLATION_EN_TO_NL,
    },
]
|
|
|
|
# Sample passage shown in the text area the first time a session loads.
_DEFAULT_PROMPT = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.

And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.

It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.

“My father is very ill,” she said without a word of introduction. “The nurse is frightened. Could you come in and help?”"""


def _estimate_time(max_length, num_return_sequences, num_beams, on_cpu):
    """Return a rough generation-time estimate in whole seconds.

    Args:
        max_length: Maximum generated length requested by the user.
        num_return_sequences: Number of sequences requested per call.
        num_beams: Beam width when beam search is selected, or ``None`` when
            sampling is used.
        on_cpu: True when no CUDA device is available (``device == -1``).

    Returns:
        The estimate truncated with ``int()``. The scaling factors are
        hard-coded heuristics, and the estimate does not account for how many
        generators will run.
    """
    estimate = max_length / 18
    if on_cpu:
        estimate *= 1 + 0.7 * (num_return_sequences - 1)
        if num_beams is not None:
            estimate *= 1.1 + 0.3 * (num_beams - 1)
    else:
        estimate *= 1 + 0.1 * (num_return_sequences - 1)
        estimate = 0.5 + estimate / 5
        if num_beams is not None:
            estimate *= 1.0 + 0.1 * (num_beams - 1)
    return int(estimate)


def main():
    """Render the Babel Streamlit page and run the generators on request.

    Builds the sidebar controls (model choice, length, repetition and seed
    settings, plus either beam-search or top-k/top-p sampling parameters) and
    a text area pre-filled with a sample passage. When "Run" is clicked, it
    passes the text and the collected keyword arguments to ``generate`` on
    every generator in the session's ``GeneratorFactory``. For each generator
    it writes the results and timing to the page, followed by memory
    statistics and a JSON record of the settings used.
    """
    st.set_page_config(
        page_title="Babel",
        layout="wide",
        initial_sidebar_state="expanded",
        page_icon="📚",
    )

    # Create the factory once per session and keep it in session_state so it
    # survives Streamlit reruns (presumably because constructing it is costly).
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel
Vertaal van en naar Engels"""
    )
    # NOTE(review): the selected model description is not used anywhere below;
    # every generator runs regardless of this choice. Confirm intent.
    st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
    st.sidebar.title("Parameters:")

    if "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = _DEFAULT_PROMPT
    st.session_state["text"] = st.text_area(
        "Enter text", st.session_state.prompt_box, height=300
    )

    # A length below 1 is meaningless for generation, so it is rejected in the UI.
    max_length = st.sidebar.number_input(
        "Lengte van de tekst",
        min_value=1,
        value=200,
        max_value=4096,
    )
    no_repeat_ngram_size = st.sidebar.number_input(
        "No-repeat NGram size", min_value=1, max_value=5, value=3
    )
    repetition_penalty = st.sidebar.number_input(
        "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
    )
    num_return_sequences = st.sidebar.number_input(
        "Num return sequences", min_value=1, max_value=5, value=1
    )

    # The seed input lives in a placeholder so it can be redrawn in place when
    # a new random seed is drawn.
    seed_placeholder = st.sidebar.empty()
    if "seed" not in st.session_state:
        print("Session state does not contain seed")
        st.session_state["seed"] = 4162549114
        print(f"Seed is set to: {st.session_state['seed']}")

    seed = seed_placeholder.number_input(
        "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
    )

    if st.button("Set new random seed"):
        st.session_state["seed"] = randint(0, 2**32 - 1)
        seed = seed_placeholder.number_input(
            "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
        )
        print(f"New random seed set to: {seed}")

    sampling_mode = st.sidebar.selectbox(
        "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
    )
    # Keyword arguments shared by both decoding modes. Mode-specific keys are
    # appended below, keeping the key order of the logged JSON unchanged.
    params = {
        "max_length": max_length,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "num_return_sequences": num_return_sequences,
    }
    num_beams = None  # stays None unless beam search is selected
    if sampling_mode == "Beam Search":
        num_beams = st.sidebar.number_input(
            "Num beams", min_value=1, max_value=10, value=4
        )
        length_penalty = st.sidebar.number_input(
            "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
        )
        params.update(
            num_beams=num_beams,
            early_stopping=True,
            length_penalty=length_penalty,
        )
    else:
        top_k = st.sidebar.number_input(
            "Top K", min_value=0, max_value=100, value=50
        )
        top_p = st.sidebar.number_input(
            "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
        )
        temperature = st.sidebar.number_input(
            "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
        )
        params.update(
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )

    if st.button("Run"):
        estimate = _estimate_time(
            max_length, num_return_sequences, num_beams, on_cpu=device == -1
        )

        with st.spinner(
            text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
        ):
            memory = psutil.virtual_memory()

            model_names = []
            for generator in generators:
                st.subheader(f"Result from {generator}")
                set_seed(seed)  # same seed for every generator, for comparability
                time_start = time.perf_counter()
                result = generator.generate(text=st.session_state.text, **params)
                time_diff = time.perf_counter() - time_start
                model_names.append(generator.model_name)

                for text in result:
                    # Two trailing spaces make Markdown render a hard line break.
                    st.write(text.replace("\n", "  \n"))
                st.write(f"--- generated in {time_diff:.2f} seconds ---")

            info = (
                "\n"
                "---\n"
                f"*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*\n"
                f"*Text generated using seed {seed}*\n"
            )
            st.write(info)

            # Build the log record separately so `params`, the kwargs given to
            # generate(), never picks up metadata keys. Record every model that
            # ran, not only the last one.
            run_record = {
                **params,
                "seed": seed,
                "prompt": st.session_state.text,
                "model": ", ".join(model_names),
            }
            params_text = json.dumps(run_record)
            print(params_text)
            st.json(params_text)
|
|
if __name__ == "__main__":
    # Start the app only when this file is executed as the main script
    # (presumably via `streamlit run`), not when it is imported.
    main()
|
|