TextGeneration / app.py
MLCraftsman's picture
Upload 3 files
de1dba4 verified
raw
history blame
2.42 kB
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Page Config
st.set_page_config(
page_title="AI Text Generator",
page_icon="🤖",
layout="wide"
)
# Sidebar
st.sidebar.title("⚙️ Settings")
model_path = st.sidebar.text_input(
"Model Path",
value="gpt2" # change to ./results if fine-tuned
)
max_length = st.sidebar.slider("Max Length", 50, 500, 150)
temperature = st.sidebar.slider("Temperature (Creativity)", 0.5, 1.5, 0.8)
top_k = st.sidebar.slider("Top-K", 10, 100, 50)
top_p = st.sidebar.slider("Top-P", 0.5, 1.0, 0.95)
device = "cuda" if torch.cuda.is_available() else "cpu"
st.sidebar.write(f"Device: **{device.upper()}**")
# Title
st.title("🤖 Professional AI Text Generator")
st.markdown("Generate creative and grammatically correct text using a GPT-based model.")
# Load Model (cached)
@st.cache_resource
def load_model(path):
tokenizer = AutoTokenizer.from_pretrained(path)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(path)
model.to(device)
model.eval()
return tokenizer, model
tokenizer, model = load_model(model_path)
# Input Area
col1, col2 = st.columns([2, 1])
with col1:
prompt = st.text_area(
"Enter your prompt:",
height=200,
placeholder="Example: Alice was walking through the forest when..."
)
with col2:
st.info("Tips:\n- Higher temperature = more creative\n- Lower temperature = more accurate\n- Use your fine-tuned model for best results")
# Generate Button
if st.button("✨ Generate Text", use_container_width=True):
if prompt.strip() == "":
st.warning("Please enter a prompt.")
else:
with st.spinner("Generating..."):
inputs = tokenizer(prompt, return_tensors="pt").to(device)
```
output = model.generate(
**inputs,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
st.subheader("Generated Output")
st.write(generated_text)
# Download option
st.download_button(
label="📥 Download Text",
data=generated_text,
file_name="generated_text.txt",
mime="text/plain"
)
```
# Footer
st.markdown("---")
st.markdown("Built with ❤️ using Streamlit + Transformers")