# ai_plugnplay_models / src / streamlit_app.py
# Hugging Face Space page header (scrape residue, kept as comments):
#   NetraVerse — "Update src/streamlit_app.py" — commit 99e51c6 (verified)
"βœ… Model {model_name} loaded successfully on {DEVICE_STR}!")
return tokenizer, model
except ValueError as e:
if "Unrecognized configuration class" in str(e):
progress_placeholder.error(f"❌ Error: {model_name} is not a causal language model suitable for text generation. Please select a different model.")
st.error(f"Technical details: {str(e)}")
else:
progress_placeholder.error(f"❌ Error loading model: {str(e)}")
raise e
except Exception as e:
progress_placeholder.error(f"❌ Unexpected error loading model: {str(e)}")
raise e
# Load the tokenizer and model once at module import.
# NOTE(review): load_model is defined above this view; its visible tail
# re-raises on failure, so a broken load stops the app here.
tokenizer, model = load_model(MODEL_NAME)
def generate_text(prompt, max_new_tokens=150, temperature=0.7, top_p=0.9):
    """Generate a sampled continuation of *prompt* with the loaded model.

    Parameters
    ----------
    prompt : str
        Text for the model to continue.
    max_new_tokens : int
        Upper bound on tokens generated beyond the prompt.
    temperature : float
        Sampling temperature; lower values are more deterministic.
    top_p : float
        Nucleus-sampling probability cutoff.

    Returns
    -------
    str
        Decoded text (prompt plus continuation), special tokens stripped.
    """
    # Tokenize, then move every input tensor onto the model's device.
    encoded = {
        name: tensor.to(DEVICE)
        for name, tensor in tokenizer(prompt, return_tensors="pt").items()
    }
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# ---------- Streamlit UI ----------
st.title(f"Language Model Text Generator ({DEVICE_STR.upper()})")
st.caption("Choose from various pre-trained language models for text generation")
# Prompt input with a sensible default so "Generate" works immediately.
prompt = st.text_area(
    "Enter prompt (English or other supported languages depending on model)",
    value="The future of artificial intelligence is",
    height=150,
)
# Generation hyperparameters: (min, max, default) per slider.
max_new_tokens = st.slider("Max output tokens", 32, 512, 150)
temperature = st.slider("Temperature", 0.1, 1.2, 0.7)
top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.9)
if st.button("Generate"):
    # Hoisted from mid-block: imports belong at the top of their scope.
    # Kept function-local so the file-level import block is untouched.
    import time

    # Progress widgets so the user sees activity during slow CPU generation.
    progress_container = st.container()
    with progress_container:
        progress_bar = st.progress(0)
        status_text = st.empty()
    try:
        status_text.text("πŸ”„ Preparing input...")
        progress_bar.progress(25)
        status_text.text("πŸ€– Generating text... (this may take 20-40s on CPU)")
        progress_bar.progress(50)
        output = generate_text(prompt, max_new_tokens, temperature, top_p)
        progress_bar.progress(100)
        status_text.text("βœ… Generation complete!")
        # Keep the success message visible briefly, then clear the widgets.
        time.sleep(1)
        progress_bar.empty()
        status_text.empty()
        st.subheader("Model output:")
        st.write(output)
    except Exception as e:
        # Remove stale progress widgets before surfacing the error.
        progress_bar.empty()
        status_text.empty()
        st.error(f"❌ Generation failed: {e}")
st.markdown("---")
# Model Status Section: three side-by-side metrics about the loaded model.
st.subheader("πŸ“Š Model Status")
col1, col2, col3 = st.columns(3)
with col1:
    st.metric("Current Model", MODEL_NAME)
with col2:
    st.metric("Device", DEVICE_STR.upper())
with col3:
    # Parameter count doubles as a "model is loaded" probe: summing
    # model.parameters() raises while the global `model` is unusable.
    try:
        model_params = sum(p.numel() for p in model.parameters())
        st.metric("Model Parameters", f"{model_params:,}")
    # Narrowed from bare `except:` so SystemExit/KeyboardInterrupt propagate.
    except Exception:
        st.metric("Model Parameters", "Loading...")
st.markdown("---")
# Static usage tips rendered at the bottom of the page.
st.markdown(
"""
**Tips**
- First run will download model to `~/.cache/huggingface`.
- DialoGPT models work well for conversational text.
- GPT-2/DistilGPT-2 work best with English prompts.
- Use smaller models (DialoGPT-small, DistilGPT-2) for faster CPU response.
"""
)