# TextGeneration / app.py
# Last commit: "Update app.py" (2bc2dcf, verified) by MLCraftsman
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# -----------------------------
# Page Configuration
# -----------------------------
# Browser-tab title and icon, plus a full-width layout for the whole app.
_PAGE_CONFIG = {
    "page_title": "AI Text Generator",
    "page_icon": "🤖",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)
# -----------------------------
# Device Setup (HF Spaces safe)
# -----------------------------
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
# Kept as a plain string because it is also shown (upper-cased) in the UI.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# -----------------------------
# Sidebar
# -----------------------------
# Every widget created inside this context is rendered in the sidebar.
with st.sidebar:
    st.title("⚙️ Settings")
    model_path = st.text_input("Model Name / Path", value="gpt2")

    # Sampling controls, each given as (label, min, max, default); the
    # selected values are used as the model.generate() arguments.
    max_new_tokens = st.slider("Max New Tokens", 20, 300, 100)
    temperature = st.slider("Temperature", 0.5, 1.5, 0.8)
    top_k = st.slider("Top-K", 10, 100, 50)
    top_p = st.slider("Top-P", 0.5, 1.0, 0.95)

    st.write(f"Device: **{device.upper()}**")
# -----------------------------
# Title
# -----------------------------
# Main-area heading followed by a one-line description of the app.
st.title(
    "🤖 Professional AI Text Generator"
)
st.markdown(
    "Generate text using Hugging Face models."
)
# -----------------------------
# Load Model (cached)
# -----------------------------
# st.cache_resource keeps the loaded (tokenizer, model) pair across script
# reruns, keyed by model name. max_entries bounds that cache: without it,
# every distinct name entered in the sidebar would keep another full model
# resident in memory for the life of the process. The trade-off is that
# switching to a different model name triggers a reload.
@st.cache_resource(max_entries=1)
def load_model(model_name):
    """Load a tokenizer and causal language model for ``model_name``.

    Args:
        model_name: Model identifier or local path passed to
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple. The model is loaded in float32,
        moved to the module-level ``device`` and switched to eval mode.

    Raises:
        Any exception raised by ``from_pretrained`` (for example for an
        unknown model id) propagates to the caller unchanged.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some causal-LM tokenizers (e.g. GPT-2's) define no pad token; fall back
    # to EOS only in that case so tokenizers with their own pad token keep it.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # safer for CPU Spaces
    )
    model.to(device)
    model.eval()  # inference only: disables dropout
    return tokenizer, model
# Load model safely
# Strip stray whitespace (e.g. from a pasted model id) and reject an empty
# name up front, instead of surfacing an opaque loader error for it.
model_path = model_path.strip()
if not model_path:
    st.warning("Please enter a model name or path in the sidebar.")
    st.stop()

try:
    tokenizer, model = load_model(model_path)
except Exception as e:  # top-level UI boundary: report any load failure
    st.error(f"Model loading failed: {e}")
    st.stop()
# -----------------------------
# Input Area
# -----------------------------
# Multi-line prompt box; its text is validated when Generate is clicked.
prompt = st.text_area(
    label="Enter your prompt:",
    placeholder="Example: Once upon a time...",
    height=200,
)
# -----------------------------
# Generate Button
# -----------------------------
# Streamlit reruns the whole script on every interaction, and st.button is
# True only on the rerun triggered by its own click. The generated text is
# therefore kept in st.session_state and rendered outside the button branch.
# Otherwise it (and the download button) would disappear on the very rerun
# that clicking "Download" triggers.
if st.button("✨ Generate Text", use_container_width=True):
    if not prompt.strip():
        st.warning("Please enter a prompt.")
        st.session_state.pop("generated_text", None)
    else:
        with st.spinner("Generating..."):
            try:
                inputs = tokenizer(prompt, return_tensors="pt").to(device)
                with torch.no_grad():
                    output = model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        do_sample=True,
                        pad_token_id=tokenizer.eos_token_id,
                    )
            except Exception as e:
                # UI boundary: show the failure instead of a raw traceback
                # (presumably e.g. a prompt longer than the model's context).
                st.error(f"Text generation failed: {e}")
                st.session_state.pop("generated_text", None)
            else:
                # output[0] holds the prompt tokens followed by the continuation.
                st.session_state["generated_text"] = tokenizer.decode(
                    output[0],
                    skip_special_tokens=True,
                )

generated_text = st.session_state.get("generated_text")
if generated_text is not None:
    st.subheader("Generated Output")
    st.write(generated_text)
    st.download_button(
        label="📥 Download",
        data=generated_text,
        file_name="generated_text.txt",
        mime="text/plain",
    )
# -----------------------------
# Footer
# -----------------------------
# A horizontal rule followed by the credits line.
for _footer_line in ("---", "Built with ❤️ using Streamlit + Transformers"):
    st.markdown(_footer_line)