# genbot / app.py
# Author: MLCraftsman — commit f0f95ab ("Upload 2 files").
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Page chrome: must be the first Streamlit call of the script run.
st.set_page_config(page_title="AI Text Generator", page_icon="🤖", layout="wide")
st.title("🤖 AI Text Generator")

# Pick the torch device once; the model and input tensors are moved onto it below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Sidebar: user-tunable generation settings, rendered top-to-bottom.
with st.sidebar:
    st.title("Settings")
    model_name = st.text_input("Model", value="gpt2")
    max_new_tokens = st.slider("Max New Tokens", min_value=20, max_value=200, value=100)
    temperature = st.slider("Temperature", min_value=0.5, max_value=1.5, value=0.8)
    st.write(f"Device: {device}")
# Load model safely
@st.cache_resource
def load_model(name):
    """Load a tokenizer and causal language model by Hub name or local path.

    The result is cached by ``st.cache_resource`` per distinct ``name``, so
    Streamlit reruns reuse the same objects instead of reloading weights.

    Args:
        name: Model identifier passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``. The model is moved to the module-level
        ``device`` and switched to eval mode.

    Raises:
        Whatever ``from_pretrained`` raises for an unknown or unreachable model;
        the caller is expected to catch and report it.
    """
    tokenizer = AutoTokenizer.from_pretrained(name)
    # Some tokenizers (e.g. GPT-2's) define no pad token; fall back to EOS so
    # padding-aware code works. Never overwrite a pad token the model already has.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=torch.float32,  # full precision: the safer choice on CPU
    )
    # NOTE: `device` is read from module scope and is not part of the cache key;
    # it is fixed for the lifetime of the process, so this is consistent.
    model.to(device)
    model.eval()  # inference only: disables dropout
    return tokenizer, model
# Resolve the model for this run. Strip stray whitespace so "gpt2 " and "gpt2"
# share one cache entry, and stop early on an empty name instead of sending ""
# to from_pretrained. st.stop() halts the script run, so the code below never
# sees an undefined tokenizer/model.
model_name = model_name.strip()
if not model_name:
    st.warning("Enter a model name")
    st.stop()
try:
    tokenizer, model = load_model(model_name)
except Exception as e:  # UI boundary: surface any load failure to the user
    st.error(f"Error loading model: {e}")
    st.stop()
# Input
prompt = st.text_area("Enter your prompt")
# Generate
if st.button("Generate"):
    if not prompt.strip():
        st.warning("Enter a prompt")
    else:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # NOTE(review): prompt tokens + max_new_tokens beyond the model's context
        # window (1024 for gpt2) is likely to make generate() raise; consider
        # truncating the prompt if long inputs are expected.
        # Show progress while generating (can take seconds on CPU); no_grad
        # skips autograd bookkeeping since this is pure inference.
        with st.spinner("Generating..."), torch.no_grad():
            output = model.generate(
                **inputs,  # input_ids + attention_mask
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,  # temperature only takes effect when sampling
                pad_token_id=tokenizer.eos_token_id,
            )
        # output[0] holds the prompt followed by the continuation; both are shown.
        text = tokenizer.decode(output[0], skip_special_tokens=True)
        st.subheader("Output")
        st.write(text)