File size: 1,524 Bytes
7a83cb9 e7ed614 7a83cb9 c9686c3 7a83cb9 c9686c3 7a83cb9 c9686c3 7a83cb9 c9686c3 7a83cb9 783e254 7a83cb9 5c083c3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | from transformers import AutoTokenizer, AutoModelForCausalLM
import streamlit as st
# Render the app across the full browser width. Keep this ahead of every
# other Streamlit call (older Streamlit versions require it to come first).
_PAGE_LAYOUT = "wide"
st.set_page_config(layout=_PAGE_LAYOUT)
class AppModel:
    """Wrap a causal language model and its tokenizer for plot generation.

    The checkpoint is loaded once, in ``__init__``; ``generate_plot`` then
    samples a continuation of a user-supplied prompt.
    """

    def __init__(self, model_name: str = "gpt2-medium"):
        """Load the tokenizer and model for *model_name*.

        Args:
            model_name: Model id or local path handed to ``from_pretrained``.
                Defaults to ``"gpt2-medium"``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the padding token so padding-aware code paths have a
        # valid id (the GPT-2 tokenizer does not define one by default).
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def generate_plot(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 100,
        top_k: int = 5,
        top_p: float = 0.35,
        temperature: float = 0.2,
    ) -> list[str]:
        """Sample a continuation of *prompt* and return the decoded text.

        Intermediate tensors and the decoded result are echoed to the page
        with ``st.write`` as debugging aids.

        Args:
            prompt: Beginning of the plot typed by the user.
            max_new_tokens: Upper bound on the number of generated tokens.
            top_k: Keep only the *top_k* most likely tokens when sampling.
            top_p: Nucleus-sampling probability mass.
            temperature: Softmax temperature; low values keep output
                conservative.

        Returns:
            A one-element list holding the decoded prompt + continuation with
            special tokens removed, or an empty list when *prompt* is empty or
            whitespace-only (there is nothing for the model to continue).
        """
        if not prompt or not prompt.strip():
            return []

        # Follow the model's device instead of hard-coding one.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        st.write("Input tensor:", inputs)
        outputs = self.model.generate(
            inputs.input_ids,
            # pad == eos here, so pass the mask and pad id explicitly rather
            # than letting ``generate`` guess them (it warns when they are unset).
            attention_mask=inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=1,
        )
        st.write("Generated output:", outputs)
        # Strip special tokens (e.g. GPT-2's end-of-text marker) so they do
        # not appear in the plot shown to the user.
        output_string = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        st.write("Decoded output:", output_string)
        return output_string
@st.cache_resource
def load_model() -> AppModel:
    """Build the app's ``AppModel`` once and reuse it across reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model checkpoint would be reloaded each time.
    """
    return AppModel()


model = load_model()

st.title("Welcome to the GPT Olympics generator")
prompt = st.text_area("Enter the beginning of your plot...")
clicked = st.button("Generate my movie")

if clicked:
    st.write("Clicked!")
    if not prompt.strip():
        # An empty prompt gives the model nothing to continue; skip generation.
        st.warning("Please enter the beginning of your plot first.")
    else:
        # Generation is slow on CPU; show progress while it runs.
        with st.spinner("Generating..."):
            generated_plot = model.generate_plot(prompt)
        st.write("Generated plot:", generated_plot)
        if generated_plot:
            st.write("Assistant:")
            st.markdown(generated_plot[0])
        else:
            st.write("No plot generated.")
# Inject the app's custom stylesheet. The path is relative to the current
# working directory; read it as UTF-8 explicitly so the result does not depend
# on the platform's default locale encoding.
with open("./style.css", encoding="utf-8") as css_file:
    css = css_file.read()
# unsafe_allow_html is needed for a raw <style> tag; acceptable only because
# the CSS comes from a file shipped with the app, not from user input.
st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
|