HMC83 committed on
Commit
5077d04
·
verified ·
1 Parent(s): 9762fa4

Update app.py

Browse files

Ditching validation for Wihtgar

Files changed (1) hide show
  1. app.py +27 -80
app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import random
4
  import time
5
  import torch
6
- import re # <-- NEW
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import spaces
9
 
@@ -25,7 +24,6 @@ except Exception as e:
25
  tokenizer = None
26
 
27
  # --- Data for the Reels ---
28
- # A list of authority and keyword combinations.
29
  FOI_COMBINATIONS = [
30
  {"authority": "Borders NHS Board", "keywords": "whistleblowing guidance, wrongdoing, public body"},
31
  {"authority": "Borders NHS Board", "keywords": "ethical support, clinical triage, minutes"},
@@ -176,27 +174,9 @@ FOI_COMBINATIONS = [
176
  {"authority": "Lancaster City Council", "keywords": "coastal erosion, protection measures, maintenance spending"},
177
  ]
178
 
179
- # Create lists for the spinning animation from the combinations above
180
  ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATIONS]))
181
  ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
182
 
183
- # --- Helper: clean model output into a numbered list starting at "1." ---
184
- def clean_and_validate_output(text: str):
185
- """
186
- Extract the main numbered list starting at '1.' and strip any closing signature lines.
187
- Always returns cleaned text and a boolean flag (True = looks fine).
188
- """
189
- # Keep everything from the first "1." onward, if present.
190
- m = re.search(r'(?m)^\s*1\.\s', text)
191
- body = text[m.start():].strip() if m else text.strip()
192
-
193
- # Remove common signature lines at the end (best-effort).
194
- body = re.sub(r'(?im)^\s*(yours.*|kind regards.*|regards.*)$', '', body).strip()
195
-
196
- # If it doesn't contain at least one numbered point, it's still usable, but we mark as not strictly-valid.
197
- is_valid = bool(re.search(r'(?m)^\s*\d+\.\s', body))
198
- return body, is_valid
199
-
200
  # --- Helper: wrap content in the FOI letter template ---
201
  def wrap_in_letter(authority: str, body: str) -> str:
202
  body = body.strip()
@@ -211,7 +191,6 @@ def wrap_in_letter(authority: str, body: str) -> str:
211
  # --- Backend Function for Local Inference ---
212
  @spaces.GPU
213
  def generate_request_local(authority, kw1, kw2, kw3):
214
- """Generates a request using the locally loaded transformer model, with validation and retry logic."""
215
  if not model or not tokenizer:
216
  return "Error: Model is not loaded. Please check the Space logs for details."
217
 
@@ -219,70 +198,43 @@ def generate_request_local(authority, kw1, kw2, kw3):
219
  keyword_string = ", ".join(keywords)
220
  prompt = (
221
  "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
222
- f"Generate ONLY the numbered list of the specific information being requested, starting at '1.' "
223
- f"for {authority}, using these keywords: {keyword_string}. "
224
- "Do not include greetings or signatures."
225
  )
226
 
227
- max_retries = 2
228
- for attempt in range(max_retries):
229
- try:
230
- # Tokenize the input prompt
231
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
232
-
233
- # Set generation parameters
234
- generation_params = {
235
- "max_new_tokens": 250,
236
- "do_sample": True,
237
- "temperature": 0.25,
238
- "top_k": 50,
239
- "top_p": 0.95,
240
- "repetition_penalty": 1.1,
241
- "streamer": None,
242
- "eos_token_id": tokenizer.eos_token_id
243
- }
244
-
245
- # Generate text sequences
246
- output_sequences = model.generate(**inputs, **generation_params)
247
-
248
- # Decode the generated text
249
- generated_text = tokenizer.decode(
250
- output_sequences[0][len(inputs["input_ids"][0]):],
251
- skip_special_tokens=True
252
- ).strip()
253
 
254
- # Remove artifact if present
255
- if generated_text.startswith('.\n'):
256
- generated_text = generated_text[2:]
 
 
 
 
 
 
257
 
258
- # Clean and validate the output
259
- cleaned_text, is_valid = clean_and_validate_output(generated_text)
260
 
261
- # Wrap in the letter template regardless; validation just influences retry behavior
262
- letter = wrap_in_letter(authority, cleaned_text)
 
 
263
 
264
- if is_valid:
265
- return letter
266
- else:
267
- print(f"Attempt {attempt + 1}/{max_retries}: Output lacked clear numbering. Retrying...")
268
 
269
- except Exception as e:
270
- print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
271
- if attempt == max_retries - 1:
272
- return f"An error occurred during text generation: {e}"
273
 
274
- # If retries failed, return the best effort letter using the last cleaned text we had
275
- return wrap_in_letter(authority, "1. [Unable to format automatically] Please restate the information requested.\n2. [Optional second point]")
276
 
277
  # --- Gradio UI and Spinning Logic ---
278
  def spin_the_reels():
279
- """A generator function that simulates spinning reels and then calls the model."""
280
- # 1. Simulate the spinning effect
281
- spin_duration = 2.0 # seconds
282
- spin_interval = 0.05 # update interval
283
  start_time = time.time()
284
  while time.time() - start_time < spin_duration:
285
- # Yield random values for each reel to create the spinning illusion
286
  yield (
287
  random.choice(ALL_AUTHORITIES_FOR_SPIN),
288
  random.choice(ALL_KEYWORDS_FOR_SPIN),
@@ -292,22 +244,18 @@ def spin_the_reels():
292
  )
293
  time.sleep(spin_interval)
294
 
295
- # 2. Select the final fixed combination
296
  final_combination = random.choice(FOI_COMBINATIONS)
297
  final_authority = final_combination["authority"]
298
 
299
- # Split, strip, and pad keywords to ensure we always have 3 for the UI
300
  keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
301
- keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
302
- kw1, kw2, kw3 = keywords_list[:3] # Take the first 3
303
 
304
- # Display the final reel values and a "Generating..." message
305
  yield (
306
  final_authority, kw1, kw2, kw3,
307
  f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
308
  )
309
 
310
- # 3. Call the local model and yield the final result
311
  generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
312
  yield (
313
  final_authority, kw1, kw2, kw3,
@@ -315,7 +263,6 @@ def spin_the_reels():
315
  )
316
 
317
  # --- CSS for Styling ---
318
- # Added min-width to reduce UI flickering on text change
319
  reels_css = """
320
  #reels-container {
321
  display: flex;
@@ -328,7 +275,7 @@ reels_css = """
328
  border-radius: 12px;
329
  background-color: #fef3c7;
330
  box-shadow: 0 4px 6px rgba(0,0,0,0.1);
331
- min-width: 150px; /* Prevents resizing/flickering during spin */
332
  }
333
  #reels-container .gradio-textbox input {
334
  font-size: 1.25rem !important;
 
3
  import random
4
  import time
5
  import torch
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import spaces
8
 
 
24
  tokenizer = None
25
 
26
  # --- Data for the Reels ---
 
27
  FOI_COMBINATIONS = [
28
  {"authority": "Borders NHS Board", "keywords": "whistleblowing guidance, wrongdoing, public body"},
29
  {"authority": "Borders NHS Board", "keywords": "ethical support, clinical triage, minutes"},
 
174
  {"authority": "Lancaster City Council", "keywords": "coastal erosion, protection measures, maintenance spending"},
175
  ]
176
 
 
177
# Deduplicated pools used by the slot-machine spin animation.
# Set comprehensions collapse duplicates; list() gives random.choice() indexable sequences.
ALL_AUTHORITIES_FOR_SPIN = list({item["authority"] for item in FOI_COMBINATIONS})
ALL_KEYWORDS_FOR_SPIN = list({keyword.strip() for item in FOI_COMBINATIONS for keyword in item["keywords"].split(',')})
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  # --- Helper: wrap content in the FOI letter template ---
181
  def wrap_in_letter(authority: str, body: str) -> str:
182
  body = body.strip()
 
191
  # --- Backend Function for Local Inference ---
192
  @spaces.GPU
193
  def generate_request_local(authority, kw1, kw2, kw3):
 
194
  if not model or not tokenizer:
195
  return "Error: Model is not loaded. Please check the Space logs for details."
196
 
 
198
  keyword_string = ", ".join(keywords)
199
  prompt = (
200
  "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
201
+ f"Generate the request text (without greeting or signature) for {authority}, using these keywords: {keyword_string}."
 
 
202
  )
203
 
204
+ try:
205
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ generation_params = {
208
+ "max_new_tokens": 250,
209
+ "do_sample": True,
210
+ "temperature": 0.25,
211
+ "top_k": 50,
212
+ "top_p": 0.95,
213
+ "repetition_penalty": 1.1,
214
+ "eos_token_id": tokenizer.eos_token_id
215
+ }
216
 
217
+ output_sequences = model.generate(**inputs, **generation_params)
 
218
 
219
+ generated_text = tokenizer.decode(
220
+ output_sequences[0][len(inputs["input_ids"][0]):],
221
+ skip_special_tokens=True
222
+ ).strip()
223
 
224
+ if generated_text.startswith('.\n'):
225
+ generated_text = generated_text[2:]
 
 
226
 
227
+ return wrap_in_letter(authority, generated_text)
 
 
 
228
 
229
+ except Exception as e:
230
+ return f"An error occurred during text generation: {e}"
231
 
232
  # --- Gradio UI and Spinning Logic ---
233
  def spin_the_reels():
234
+ spin_duration = 2.0
235
+ spin_interval = 0.05
 
 
236
  start_time = time.time()
237
  while time.time() - start_time < spin_duration:
 
238
  yield (
239
  random.choice(ALL_AUTHORITIES_FOR_SPIN),
240
  random.choice(ALL_KEYWORDS_FOR_SPIN),
 
244
  )
245
  time.sleep(spin_interval)
246
 
 
247
  final_combination = random.choice(FOI_COMBINATIONS)
248
  final_authority = final_combination["authority"]
249
 
 
250
  keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
251
+ keywords_list += [''] * (3 - len(keywords_list))
252
+ kw1, kw2, kw3 = keywords_list[:3]
253
 
 
254
  yield (
255
  final_authority, kw1, kw2, kw3,
256
  f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
257
  )
258
 
 
259
  generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
260
  yield (
261
  final_authority, kw1, kw2, kw3,
 
263
  )
264
 
265
  # --- CSS for Styling ---
 
266
  reels_css = """
267
  #reels-container {
268
  display: flex;
 
275
  border-radius: 12px;
276
  background-color: #fef3c7;
277
  box-shadow: 0 4px 6px rgba(0,0,0,0.1);
278
+ min-width: 150px;
279
  }
280
  #reels-container .gradio-textbox input {
281
  font-size: 1.25rem !important;