bruce1113 committed on
Commit f969350 · verified · 1 Parent(s): 7546cf2

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +77 -148

src/streamlit_app.py CHANGED
@@ -2,15 +2,16 @@ import streamlit as st
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM

- # -----------------------------
  # Model config
- # -----------------------------
- MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

  @st.cache_resource
- def load_qwen_model():
      tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-     # Use float16 on GPU if available, else float32 on CPU
      dtype = torch.float16 if torch.cuda.is_available() else torch.float32
      model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,
@@ -20,14 +21,13 @@ def load_qwen_model():
      model.to(device)
      return tokenizer, model, device

- tokenizer, model, device = load_qwen_model()

- # -----------------------------
- # Prompt / behavior config
- # -----------------------------
  SYSTEM_PROMPT = """
  You are taking part in a research study on how people read summaries.
-
  You will be given the transcript of an audio clip that a participant listened to.
  Your job is to write a single dense paragraph that summarizes the audio.
@@ -40,13 +40,40 @@ Follow these rules very carefully:
  4. Do NOT mark which details are incorrect, and do NOT mention that some facts are invented.
  5. Use clear, natural language and a neutral tone.
  6. Do NOT use bullet points, headings, or lists. Only one continuous paragraph.
  """

- def generate_summary_from_transcript(transcript_text: str) -> str:
      """
-     Use Qwen to generate a dense, slightly hallucination-seeded summary
-     from the given transcript.
      """
      messages = [
          {"role": "system", "content": SYSTEM_PROMPT},
          {
@@ -56,11 +83,12 @@ Here is the transcript of the audio the participant listened to:

  \"\"\"{transcript_text}\"\"\"

- Write the summary following the rules.
  """,
          },
      ]

      inputs = tokenizer.apply_chat_template(
          messages,
          add_generation_prompt=True,
@@ -72,152 +100,53 @@ Write the summary following the rules.
      with torch.no_grad():
          outputs = model.generate(
              **inputs,
-             max_new_tokens=300,
              do_sample=True,
-             temperature=0.9,
              top_p=0.95,
-             repetition_penalty=1.05,
          )

-     # Only decode the newly generated tokens after the prompt
      generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
-     summary = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
-     return summary
-
- # -----------------------------
- # Streamlit UI
- # -----------------------------
- st.set_page_config(page_title="LLM Study", layout="wide")
-
- st.title("Ask")
-
- # Sidebar
- with st.sidebar:
-     st.header("Instructions (Researcher)")
-     st.markdown(
-         """
- 1. Have the participant listen to the audio **outside** this app (or in another tab).
- 2. Paste the **transcript** of that audio into the text box.
- 3. Click **Generate summary**.
- 4. Show the generated paragraph to the participant for reading / annotation / whatever your protocol specifies.
-
- You can also upload a `.txt` file containing the transcript.
- """
-     )
-     st.markdown("---")
-     st.caption(f"Model: `{MODEL_NAME}`")
-
- # -----------------------------
- # Input area
- # -----------------------------
- col_left, col_right = st.columns([2, 1])
-
- with col_left:
-     st.subheader("Transcript input")
-
-     uploaded_file = st.file_uploader(
-         "Optional: upload a .txt file with the transcript",
-         type=["txt"],
-         help="If provided, its content will be loaded into the transcript box below.",
-     )
-
-     # We store transcript in session_state to allow re-editing after upload
-     if "transcript_text" not in st.session_state:
-         st.session_state.transcript_text = ""
-
-     if uploaded_file is not None:
-         file_bytes = uploaded_file.read()
-         try:
-             st.session_state.transcript_text = file_bytes.decode("utf-8")
-         except UnicodeDecodeError:
-             st.warning("Could not decode file as UTF-8. Please check the file encoding.")
-
-     transcript_text = st.text_area(
-         "Transcript of the audio (paste or edit here):",
-         value=st.session_state.transcript_text,
-         height=300,
-     )
-     # Keep session_state in sync with edits
-     st.session_state.transcript_text = transcript_text
-
- with col_right:
-     st.subheader("Generation controls")
-     max_new_tokens = st.slider("Max new tokens", 128, 512, 300, step=32)
-     temperature = st.slider("Temperature (creativity)", 0.1, 1.5, 0.9, step=0.1)
-     top_p = st.slider("Top-p (nucleus sampling)", 0.1, 1.0, 0.95, step=0.05)
-
-     st.caption(
-         """
- Higher temperature / top-p generally increases variation and may strengthen or increase
- the hallucinated details. For a controlled study, you might keep these fixed.
- """
-     )
-
- # Re-bind hyperparameters into the generation function without changing its signature
- def generate_summary_from_transcript_with_params(transcript_text: str) -> str:
-     messages = [
-         {"role": "system", "content": SYSTEM_PROMPT},
          {
-             "role": "user",
-             "content": f"""
- Here is the transcript of the audio the participant listened to:
-
- \"\"\"{transcript_text}\"\"\"
-
- Write the summary following the rules.
- """,
-         },
      ]

-     inputs = tokenizer.apply_chat_template(
-         messages,
-         add_generation_prompt=True,
-         tokenize=True,
-         return_dict=True,
-         return_tensors="pt",
-     ).to(device)

-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=int(max_new_tokens),
-             do_sample=True,
-             temperature=float(temperature),
-             top_p=float(top_p),
-             repetition_penalty=1.05,
-         )

-     generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
-     summary = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
-     return summary
-
- # -----------------------------
- # Generate button + output
- # -----------------------------
- st.markdown("---")
- generate_clicked = st.button("Generate summary")
-
- if generate_clicked:
-     if not transcript_text.strip():
-         st.warning("Please provide a transcript (paste text or upload a .txt file).")
-     else:
-         with st.spinner("Generating summary with Qwen2.5-3B-Instruct..."):
-             summary = generate_summary_from_transcript_with_params(transcript_text)
-
-         st.subheader("Model-generated summary (show this to participant):")
-         st.write(summary)
-
-         with st.expander("Show transcript (for researcher)"):
-             st.text(transcript_text)
-
-         with st.expander("Debug info (for researcher)"):
-             st.json(
-                 {
-                     "max_new_tokens": max_new_tokens,
-                     "temperature": temperature,
-                     "top_p": top_p,
-                     "transcript_chars": len(transcript_text),
-                 }
-             )
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM

+ # ------------------------------------------------
  # Model config
+ # ------------------------------------------------
+ # If it's too slow, you can change this to:
+ # "Qwen/Qwen2.5-0.5B-Instruct" (much faster)
+ MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"

  @st.cache_resource
+ def load_model():
      tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
      dtype = torch.float16 if torch.cuda.is_available() else torch.float32
      model = AutoModelForCausalLM.from_pretrained(
          MODEL_NAME,

      model.to(device)
      return tokenizer, model, device

+ tokenizer, model, device = load_model()
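
The diff viewer hides the middle of the loader (the unchanged `from_pretrained` keyword arguments and the `device` definition, new-file lines 18-20). A minimal sketch of the full function, assuming the hidden lines just forward `dtype` and pick CUDA when available; the `torch_dtype` kwarg and the `device` string are assumptions, not shown in the diff:

```python
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"

@st.cache_resource
def load_model():
    # Cached so the model is loaded once per Streamlit server process.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,  # assumption: the elided lines pass dtype through
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: elided
    model.to(device)
    return tokenizer, model, device
```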

+ # ------------------------------------------------
+ # System prompt (for your user study)
+ # ------------------------------------------------
  SYSTEM_PROMPT = """
  You are taking part in a research study on how people read summaries.
  You will be given the transcript of an audio clip that a participant listened to.
  Your job is to write a single dense paragraph that summarizes the audio.

  4. Do NOT mark which details are incorrect, and do NOT mention that some facts are invented.
  5. Use clear, natural language and a neutral tone.
  6. Do NOT use bullet points, headings, or lists. Only one continuous paragraph.
+
+ The user will usually paste the transcript of the audio as their message.
+ Just respond with the summary paragraph.
+ """
+
+ # Sample transcript for testing. chat_with_qwen() below reads the transcript
+ # out of the chat history, so this constant is not referenced by the app flow.
+ transcript_text = """
+ Virginia Du Fray was one of the most prominent accusers of notorious US sex offender Jeffrey Epstein.
+ I've been fighting that very world to this day and I won't stop fighting. The 41-year-old died by
+ suicide at her property north of Perth in April this year, leaving behind a significant
+ estate, but no valid will. Now, a legal stash is underway in Perth's Supreme Court over access to
+ Ms. Jafrey's estate, which is thought to be worth millions. The court has appointed an interim administrator to oversee
+ the estate after Ms. Jafrey's teenage sons applied to be the administrators, prompting a counter suit launched by Ms. Jup Fray's lawyer,
+ Carrie Lden, and her former friend and carer, Cheryl Meyers. The court today heard that their counter claim, if successful, would see M.
+ Jupy's aranged husband, Robert, removed from his entitlements to her estate. Once copies of those pleadings are provided to the
+ media uh then you will be able to establish the basis for that counter claim. WA Supreme Court registar Danielle Davies told
+ the court Jafrey's former husband and her young daughter should be added to the proceedings. The case is expected to resume
+ in the new year. Rian Shine, ABC News.
  """

+ def chat_with_qwen(chat_history):
      """
+     chat_history: list of {"role": "user"/"assistant", "content": str}
+     We treat the **last user message** as the transcript text.
+     Returns the assistant's reply string (the summary paragraph).
      """
+
+     # 1) Get the most recent user message = transcript text
+     transcript_text = ""  # initialize so the f-string below is safe with no user turn
+     for msg in reversed(chat_history):
+         if msg["role"] == "user":
+             transcript_text = msg["content"]
+             break
+
+     # 2) Build messages: system prompt + one user turn containing the transcript
      messages = [
          {"role": "system", "content": SYSTEM_PROMPT},
          {

  \"\"\"{transcript_text}\"\"\"

+ Write the summary following the rules in the system prompt.
  """,
          },
      ]

+     # 3) Apply Qwen chat template and generate
      inputs = tokenizer.apply_chat_template(
          messages,
          add_generation_prompt=True,

      with torch.no_grad():
          outputs = model.generate(
              **inputs,
+             max_new_tokens=200,  # one dense paragraph
              do_sample=True,
+             temperature=0.8,
              top_p=0.95,
          )

+     # 4) Decode only the new tokens after the prompt
      generated_ids = outputs[0][inputs["input_ids"].shape[1]:]
+     reply = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+     return reply
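
For reference, a minimal driver for `chat_with_qwen`, using the same message shape the UI below builds (the transcript string here is a hypothetical placeholder):

```python
history = [
    {"role": "assistant", "content": "Hi!"},
    {"role": "user", "content": "Presenter: The council voted five to two ..."},  # hypothetical
]

# chat_with_qwen scans the history backwards and keeps only the last user
# turn, so earlier assistant messages never reach the model.
print(chat_with_qwen(history))  # prints one dense summary paragraph
```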

+ # ------------------------------------------------
+ # Simple chat UI (original style)
+ # ------------------------------------------------
+ st.set_page_config(page_title="Simple LLM", page_icon="💬")
+
+ st.title("💬 Simple LLM")
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+     st.session_state["messages"] = [
          {
+             "role": "assistant",
+             "content": "Hi!"
+         }
      ]
+
+ # Display chat history
+ for msg in st.session_state["messages"]:
+     with st.chat_message(msg["role"]):
+         st.markdown(msg["content"])
+
+ # Chat input (simple, original format)
+ user_input = st.chat_input("Paste transcript or ask something...")
+
+ if user_input:
+     # Add user message
+     st.session_state["messages"].append({"role": "user", "content": user_input})
+     with st.chat_message("user"):
+         st.markdown(user_input)
+
+     # Generate model reply
+     with st.chat_message("assistant"):
+         with st.spinner("Thinking..."):
+             reply = chat_with_qwen(st.session_state["messages"])
+             st.markdown(reply)
+
+     # Save reply to history
+     st.session_state["messages"].append({"role": "assistant", "content": reply})
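
To try the commit locally, `streamlit run src/streamlit_app.py` should suffice once `streamlit`, `torch`, and `transformers` are installed. One caveat: Streamlit expects `st.set_page_config` to be the first Streamlit call on a page, and here `load_model()` runs before it (`st.cache_resource` may render a "Running load_model()..." status on the first run). If that raises a `StreamlitAPIException`, a small reordering at the top of the script, sketched under that assumption, fixes it:

```python
# Sketch: make set_page_config the first Streamlit call, then load the model.
st.set_page_config(page_title="Simple LLM", page_icon="💬")

tokenizer, model, device = load_model()  # cached; still loads only once
```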