"""Streamlit app that generates text with a Hugging Face causal language model.

The sidebar selects the model and the sampling parameters; the main area
takes a prompt, runs generation, and offers the result as a download.
"""

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Session-state key under which the most recent generation is stored.
RESULT_KEY = "generated_text"

# Upper bound on simultaneously cached models. The model name is free text,
# so an unbounded cache would keep every model ever requested in memory.
MAX_CACHED_MODELS = 2

# -----------------------------
# Page Configuration
# -----------------------------
st.set_page_config(
    page_title="AI Text Generator",
    page_icon="🤖",
    layout="wide"
)

# -----------------------------
# Device Setup (HF Spaces safe)
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# Sidebar
# -----------------------------
st.sidebar.title("⚙️ Settings")

model_path = st.sidebar.text_input(
    "Model Name / Path",
    value="gpt2"
)

max_new_tokens = st.sidebar.slider("Max New Tokens", 20, 300, 100)
temperature = st.sidebar.slider("Temperature", 0.5, 1.5, 0.8)
top_k = st.sidebar.slider("Top-K", 10, 100, 50)
top_p = st.sidebar.slider("Top-P", 0.5, 1.0, 0.95)

st.sidebar.write(f"Device: **{device.upper()}**")

# -----------------------------
# Title
# -----------------------------
st.title("🤖 Professional AI Text Generator")
st.markdown("Generate text using Hugging Face models.")


# -----------------------------
# Load Model (cached)
# -----------------------------
@st.cache_resource(max_entries=MAX_CACHED_MODELS)
def load_model(model_name):
    """Load the tokenizer and causal LM for ``model_name`` onto ``device``.

    The result is cached per ``model_name`` so Streamlit reruns do not reload
    the weights; at most ``MAX_CACHED_MODELS`` entries are kept.

    Args:
        model_name: Hub identifier or local path passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is moved to ``device`` and
        put in eval mode.

    Raises:
        Whatever ``from_pretrained`` raises for an unknown or unloadable
        model; the caller below reports it with ``st.error``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Only fall back to EOS when the checkpoint defines no pad token, so a
    # model that ships its own pad token keeps it.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32  # safer for CPU Spaces
    )
    model.to(device)
    model.eval()
    return tokenizer, model


def generate_text(tokenizer, model, prompt, *, max_new_tokens, temperature,
                  top_k, top_p):
    """Sample a completion for ``prompt`` and return it as decoded text.

    Args:
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.
        prompt: Text to tokenize and feed to ``model.generate``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    return tokenizer.decode(output[0], skip_special_tokens=True)


# Load model safely
try:
    tokenizer, model = load_model(model_path)
except Exception as e:
    st.error(f"Model loading failed: {e}")
    st.stop()

# -----------------------------
# Input Area
# -----------------------------
prompt = st.text_area(
    "Enter your prompt:",
    height=200,
    placeholder="Example: Once upon a time..."
)

# -----------------------------
# Generate Button
# -----------------------------
if st.button("✨ Generate Text", use_container_width=True):
    if not prompt.strip():
        st.warning("Please enter a prompt.")
    else:
        with st.spinner("Generating..."):
            try:
                result = generate_text(
                    tokenizer,
                    model,
                    prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                )
            except Exception as e:
                # UI boundary: report the failure the same way as a model
                # loading failure instead of showing a raw traceback.
                st.error(f"Text generation failed: {e}")
            else:
                st.session_state[RESULT_KEY] = result

# -----------------------------
# Output
# -----------------------------
# Rendered from session state rather than inside the button branch: clicking
# the download button reruns the script, and on that rerun st.button() is
# False, which would otherwise make the generated text disappear.
if RESULT_KEY in st.session_state:
    generated_text = st.session_state[RESULT_KEY]

    st.subheader("Generated Output")
    st.write(generated_text)

    st.download_button(
        label="📥 Download",
        data=generated_text,
        file_name="generated_text.txt",
        mime="text/plain"
    )

# -----------------------------
# Footer
# -----------------------------
st.markdown("---")
st.markdown("Built with ❤️ using Streamlit + Transformers")