Hugging Face Spaces app source (Streamlit). Space status at capture: Runtime error.
| import streamlit as st | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# -----------------------------
# Page Configuration
# -----------------------------
# Must run before any other st.* call on a rerun.
st.set_page_config(
    page_title="AI Text Generator",
    page_icon="🤖",
    layout="wide",
)
# -----------------------------
# Device Setup (HF Spaces safe)
# -----------------------------
# Default to CPU (free Spaces have no GPU); upgrade when CUDA is present.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# -----------------------------
# Sidebar
# -----------------------------
st.sidebar.title("⚙️ Settings")

# Any causal-LM checkpoint name on the Hub, or a local path.
model_path = st.sidebar.text_input(
    "Model Name / Path",
    value="gpt2",
)

# Sampling controls — slider args are (min, max, default).
max_new_tokens = st.sidebar.slider("Max New Tokens", 20, 300, 100)
temperature = st.sidebar.slider("Temperature", 0.5, 1.5, 0.8)
top_k = st.sidebar.slider("Top-K", 10, 100, 50)
top_p = st.sidebar.slider("Top-P", 0.5, 1.0, 0.95)

st.sidebar.write(f"Device: **{device.upper()}**")

# -----------------------------
# Title
# -----------------------------
st.title("🤖 Professional AI Text Generator")
st.markdown("Generate text using Hugging Face models.")
# -----------------------------
# Load Model (cached)
# -----------------------------
@st.cache_resource(show_spinner="Loading model...")
def load_model(model_name):
    """Load and cache a tokenizer + causal LM for *model_name*.

    Fix: the header comment promised caching, but without
    ``@st.cache_resource`` the model was re-downloaded and re-instantiated
    on EVERY Streamlit rerun (i.e. every widget interaction), which
    exhausts memory/time on a CPU Space and is a likely cause of the
    Space's "Runtime error". The decorator memoizes per model name.

    Args:
        model_name: Hub checkpoint id or local path (e.g. "gpt2").

    Returns:
        (tokenizer, model) — model moved to ``device`` and in eval mode.

    Raises:
        Propagates any download/instantiation error to the caller,
        which surfaces it via st.error.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2-family tokenizers ship without a pad token; reuse EOS only
    # when one is actually missing, so models with a real pad token keep it.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # safer for CPU Spaces
    )
    model.to(device)
    model.eval()
    return tokenizer, model
# Load model safely: show the failure in the UI and halt the rerun
# instead of dumping a raw traceback.
try:
    tokenizer, model = load_model(model_path)
except Exception as e:
    st.error(f"Model loading failed: {e}")
    st.stop()
# -----------------------------
# Input Area
# -----------------------------
prompt = st.text_area(
    "Enter your prompt:",
    height=200,
    placeholder="Example: Once upon a time...",
)
# -----------------------------
# Generate Button
# -----------------------------
if st.button("✨ Generate Text", use_container_width=True):
    if not prompt.strip():
        st.warning("Please enter a prompt.")
    else:
        with st.spinner("Generating..."):
            # Move token ids to the same device as the model.
            encoded = tokenizer(prompt, return_tensors="pt").to(device)
            # Sampling-based decoding with the sidebar hyperparameters;
            # no_grad avoids building an autograd graph during inference.
            with torch.no_grad():
                generated_ids = model.generate(
                    **encoded,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            generated_text = tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True,
            )
        st.subheader("Generated Output")
        st.write(generated_text)
        st.download_button(
            label="📥 Download",
            data=generated_text,
            file_name="generated_text.txt",
            mime="text/plain",
        )
# -----------------------------
# Footer
# -----------------------------
st.markdown("---")
st.markdown("Built with ❤️ using Streamlit + Transformers")