| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import streamlit as st | |
# Page-wide Streamlit configuration: render the app using the full browser width.
_PAGE_LAYOUT = "wide"
st.set_page_config(layout=_PAGE_LAYOUT)
class AppModel:
    """Text-generation backend for the plot generator page.

    Holds a Hugging Face tokenizer / causal-LM pair and continues a
    user-supplied prompt with sampled text.
    """

    def __init__(self, model_name: str = "gpt2-medium", device: str = "cpu"):
        """Load the tokenizer and model.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``
                (defaults to ``"gpt2-medium"``, as before).
            device: torch device string; both the model and the tokenized
                inputs are moved here (defaults to ``"cpu"``, as before).
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the pad token so generation has a valid pad id.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    def generate_plot(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 100,
        debug: bool = False,
    ) -> list:
        """Continue *prompt* with sampled text.

        Args:
            prompt: Beginning of the plot; must be non-empty.
            max_new_tokens: Upper bound on the number of generated tokens.
            debug: When true, write the intermediate tensors and decoded
                output to the Streamlit page (previously always shown).

        Returns:
            A list with one decoded string per returned sequence (one here).
            Each string is the full sequence returned by ``generate``, with
            special tokens such as ``<|endoftext|>`` removed.

        Raises:
            ValueError: If *prompt* is empty.
        """
        if not prompt:
            # An empty prompt tokenizes to zero tokens, leaving generate()
            # nothing to condition on; fail with a clear message instead.
            raise ValueError("prompt must be a non-empty string")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        if debug:
            st.write("Input tensor:", inputs)

        outputs = self.model.generate(
            inputs.input_ids,
            # Pass the mask and pad id explicitly: pad == eos here, so
            # generate() cannot infer the attention mask on its own.
            attention_mask=inputs.attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=5,
            top_p=0.35,
            temperature=0.2,
            num_return_sequences=1,
        )
        if debug:
            st.write("Generated output:", outputs)

        # Drop special tokens so they are not rendered into the page.
        output_string = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        if debug:
            st.write("Decoded output:", output_string)
        return output_string
@st.cache_resource
def _load_model():
    """Construct the AppModel once and share it across reruns and sessions.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the model would be reloaded on every button click.
    """
    return AppModel()


def _inject_css(path: str = "./style.css") -> None:
    """Inject the stylesheet at *path* into the page, if the file exists."""
    try:
        with open(path, encoding="utf-8") as f:
            css = f.read()
    except FileNotFoundError:
        # Styling is cosmetic; a missing stylesheet should not crash the app.
        return
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


model = _load_model()
# Inject styles before the (slow) generation step so the page is styled
# while it waits.
_inject_css()

st.title("Welcome to the GPT Olympics generator")
prompt = st.text_area("Enter the beginning of your plot...")
clicked = st.button("Generate my movie")

if clicked:
    if not prompt.strip():
        # Nothing to continue; ask for input instead of calling the model.
        st.warning("Please enter the beginning of your plot first.")
    else:
        with st.spinner("Generating your plot..."):
            generated_plot = model.generate_plot(prompt)
        # batch_decode always returns a list, so check for actual text.
        if generated_plot and generated_plot[0].strip():
            st.write("Assistant:")
            st.markdown(generated_plot[0])
        else:
            st.write("No plot generated.")