NetraVerse committed on
Commit
99e51c6
·
verified ·
1 Parent(s): 8d83a99

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +108 -40
src/streamlit_app.py CHANGED
@@ -1,40 +1,108 @@
1
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

# Streamlit "magic": a bare module-level string literal is rendered as Markdown.
"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

# Interactive controls: density of points and number of revolutions.
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Parametrize the spiral on t in [0, 1]: the angle grows linearly with t
# and the radius equals t, tracing an Archimedean-style spiral.
indices = np.linspace(0, 1, num_points)
theta = 2 * np.pi * num_turns * indices
radius = indices

# Convert polar (radius, theta) to Cartesian coordinates for plotting.
x = radius * np.cos(theta)
y = radius * np.sin(theta)

df = pd.DataFrame({
    "x": x,
    "y": y,
    "idx": indices,  # position along the spiral, used for color
    "rand": np.random.randn(num_points),  # random values drive point size
})

# Scatter chart: color encodes progress along the spiral, size is random.
st.altair_chart(alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "βœ… Model {model_name} loaded successfully on {DEVICE_STR}!")
2
+ return tokenizer, model
3
+
4
+ except ValueError as e:
5
+ if "Unrecognized configuration class" in str(e):
6
+ progress_placeholder.error(f"❌ Error: {model_name} is not a causal language model suitable for text generation. Please select a different model.")
7
+ st.error(f"Technical details: {str(e)}")
8
+ else:
9
+ progress_placeholder.error(f"❌ Error loading model: {str(e)}")
10
+ raise e
11
+ except Exception as e:
12
+ progress_placeholder.error(f"❌ Unexpected error loading model: {str(e)}")
13
+ raise e
14
+
15
# Load the tokenizer/model pair once at import time (cached by load_model).
tokenizer, model = load_model(MODEL_NAME)
16
+
17
def generate_text(prompt, max_new_tokens=150, temperature=0.7, top_p=0.9):
    """Sample a continuation of *prompt* from the globally loaded model.

    Tokenizes the prompt, moves the tensors to the active device, runs
    nucleus sampling under ``torch.no_grad()``, and returns the decoded
    text (prompt included) with special tokens stripped.

    Args:
        prompt: Input text to continue.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Softmax temperature for sampling.
        top_p: Nucleus-sampling probability mass cutoff.

    Returns:
        The decoded generation as a plain string.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    # Move every input tensor (ids, attention mask, ...) to the model device.
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # GPT-style models have no pad token, so EOS doubles as padding.
    eos_id = tokenizer.eos_token_id
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=eos_id,
            eos_token_id=eos_id,
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
32
+
33
# ---------- Streamlit UI ----------
# Widgets render in statement order; do not reorder these calls.
st.title(f"Language Model Text Generator ({DEVICE_STR.upper()})")
st.caption("Choose from various pre-trained language models for text generation")

# Free-form prompt input with a sensible default.
prompt = st.text_area(
    "Enter prompt (English or other supported languages depending on model)",
    value="The future of artificial intelligence is",
    height=150,
)

# Sampling hyper-parameters, passed straight through to generate_text().
max_new_tokens = st.slider("Max output tokens", 32, 512, 150)
temperature = st.slider("Temperature", 0.1, 1.2, 0.7)  # higher = more random
top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.9)
46
+
47
if st.button("Generate"):
    # Placeholder widgets that give feedback while the model runs.
    feedback = st.container()

    with feedback:
        bar = st.progress(0)
        status = st.empty()

    try:
        status.text("πŸ”„ Preparing input...")
        bar.progress(25)

        status.text("πŸ€– Generating text... (this may take 20-40s on CPU)")
        bar.progress(50)

        result = generate_text(prompt, max_new_tokens, temperature, top_p)

        bar.progress(100)
        status.text("βœ… Generation complete!")

        # Briefly show the completed state, then remove the progress widgets.
        import time
        time.sleep(1)
        bar.empty()
        status.empty()

        st.subheader("Model output:")
        st.write(result)

    except Exception as e:
        # Clear the progress widgets before surfacing the failure to the user.
        bar.empty()
        status.empty()
        st.error(f"❌ Generation failed: {e}")
80
+
81
st.markdown("---")

# ---------- Model Status ----------
# Summary metrics: which model is active, which device it runs on, and its size.
st.subheader("πŸ“Š Model Status")
col1, col2, col3 = st.columns(3)

with col1:
    st.metric("Current Model", MODEL_NAME)
with col2:
    st.metric("Device", DEVICE_STR.upper())
with col3:
    # The parameter count doubles as a "model is loaded" probe: if the model
    # is not usable yet, fall back to a placeholder label instead of crashing.
    try:
        model_params = sum(p.numel() for p in model.parameters())
        st.metric("Model Parameters", f"{model_params:,}")
    # Fix: was a bare `except:`, which also swallows SystemExit and
    # KeyboardInterrupt; catch Exception only.
    except Exception:
        st.metric("Model Parameters", "Loading...")
98
+
99
st.markdown("---")
# Static usage tips rendered as Markdown at the bottom of the page.
st.markdown(
    """
**Tips**
- First run will download model to `~/.cache/huggingface`.
- DialoGPT models work well for conversational text.
- GPT-2/DistilGPT-2 work best with English prompts.
- Use smaller models (DialoGPT-small, DistilGPT-2) for faster CPU response.
"""
)