HMC83 committed on
Commit
dbe7eb9
·
verified ·
1 Parent(s): c10fcc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -52
app.py CHANGED
@@ -3,10 +3,11 @@ import os
3
  import random
4
  import time
5
  import torch
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import spaces
8
 
9
- MODEL_ID = "HMC83/Wihtgar-650M-DPO-Requests-2"
10
 
11
  # --- Load Model and Tokenizer ---
12
  print("Loading model and tokenizer...")
@@ -179,46 +180,33 @@ FOI_COMBINATIONS = [
179
  ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATIONS]))
180
  ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
181
 
182
-
183
- # --- Helper Function for Cleaning and Validation ---
184
- def clean_and_validate_output(raw_text: str) -> tuple[str, bool]:
185
  """
186
- Cleans the model's output by keeping only the first complete request.
187
-
188
- It validates that the output contains essential markers ("Dear" and "[Your Name]").
189
- If it detects that the model has started generating a second request, it truncates
190
- the string after the first "[Your Name]".
191
-
192
- Args:
193
- raw_text: The raw string output from the language model.
194
-
195
- Returns:
196
- A tuple containing:
197
- - The cleaned text.
198
- - A boolean flag: True if the output is valid, False if it is malformed.
199
  """
200
- end_marker = "[Your Name]"
201
- start_marker = "Dear"
 
202
 
203
- # Validate: A valid request must contain the end marker.
204
- if end_marker not in raw_text:
205
- return raw_text, False # Malformed, signal for regeneration.
206
 
207
- # Find the end of the first complete request.
208
- first_end_pos = raw_text.find(end_marker)
209
- end_of_first_request_index = first_end_pos + len(end_marker)
210
-
211
- # Check if a second request has started after the first one ended.
212
- start_of_second_request_pos = raw_text.find(start_marker, end_of_first_request_index)
213
-
214
- if start_of_second_request_pos != -1:
215
- # If a second request is found, truncate to keep only the first one.
216
- cleaned_text = raw_text[:end_of_first_request_index]
217
- return cleaned_text, True
218
- else:
219
- # No second request found, the output is valid.
220
- return raw_text, True
221
 
 
 
 
 
 
 
 
 
 
 
222
 
223
  # --- Backend Function for Local Inference ---
224
  @spaces.GPU
@@ -231,7 +219,9 @@ def generate_request_local(authority, kw1, kw2, kw3):
231
  keyword_string = ", ".join(keywords)
232
  prompt = (
233
  "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
234
- f"""Generate a formal Freedom of Information request to {authority} using these keywords: {keyword_string}"""
 
 
235
  )
236
 
237
  max_retries = 2
@@ -243,7 +233,7 @@ def generate_request_local(authority, kw1, kw2, kw3):
243
  # Set generation parameters
244
  generation_params = {
245
  "max_new_tokens": 340,
246
- "temperature": 0.3,
247
  "top_p": 0.95,
248
  "top_k": 50,
249
  "repetition_penalty": 1.1,
@@ -264,22 +254,24 @@ def generate_request_local(authority, kw1, kw2, kw3):
264
  if generated_text.startswith('.\n'):
265
  generated_text = generated_text[2:]
266
 
267
- # **NEW**: Clean and validate the output
268
  cleaned_text, is_valid = clean_and_validate_output(generated_text)
269
-
 
 
 
270
  if is_valid:
271
- return cleaned_text # Success! Return the valid, cleaned text.
272
  else:
273
- print(f"Attempt {attempt + 1}/{max_retries}: Malformed output detected. Retrying...")
274
 
275
  except Exception as e:
276
  print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
277
  if attempt == max_retries - 1:
278
  return f"An error occurred during text generation: {e}"
279
 
280
- # If the loop finishes, all retries have failed
281
- return "Failed to generate a valid request after multiple attempts. Please try again."
282
-
283
 
284
  # --- Gradio UI and Spinning Logic ---
285
  def spin_the_reels():
@@ -298,22 +290,22 @@ def spin_the_reels():
298
  "Spinning..."
299
  )
300
  time.sleep(spin_interval)
301
-
302
  # 2. Select the final fixed combination
303
  final_combination = random.choice(FOI_COMBINATIONS)
304
  final_authority = final_combination["authority"]
305
-
306
  # Split, strip, and pad keywords to ensure we always have 3 for the UI
307
  keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
308
  keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
309
  kw1, kw2, kw3 = keywords_list[:3] # Take the first 3
310
-
311
  # Display the final reel values and a "Generating..." message
312
  yield (
313
  final_authority, kw1, kw2, kw3,
314
  f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
315
  )
316
-
317
  # 3. Call the local model and yield the final result
318
  generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
319
  yield (
@@ -369,9 +361,9 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
369
  reel2 = gr.Textbox(label="Keyword 1", interactive=False, elem_id="reel-2", scale=1)
370
  reel3 = gr.Textbox(label="Keyword 2", interactive=False, elem_id="reel-3", scale=1)
371
  reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
372
-
373
  pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
374
-
375
  output_request = gr.Textbox(
376
  label="Generated FOI Request",
377
  lines=15,
@@ -386,4 +378,4 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
386
  )
387
 
388
  if __name__ == "__main__":
389
- demo.launch()
 
3
  import random
4
  import time
5
  import torch
6
+ import re # <-- NEW
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import spaces
9
 
10
+ MODEL_ID = "HMC83/Wihtgar-650M-SFT-Requests_2-Merged"
11
 
12
  # --- Load Model and Tokenizer ---
13
  print("Loading model and tokenizer...")
 
180
  ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATIONS]))
181
  ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
182
 
183
+ # --- Helper: clean model output into a numbered list starting at "1." ---
184
+ def clean_and_validate_output(text: str):
 
185
  """
186
+ Extract the main numbered list starting at '1.' and strip any closing signature lines.
187
+ Always returns cleaned text and a boolean flag (True = looks fine).
 
 
 
 
 
 
 
 
 
 
 
188
  """
189
+ # Keep everything from the first "1." onward, if present.
190
+ m = re.search(r'(?m)^\s*1\.\s', text)
191
+ body = text[m.start():].strip() if m else text.strip()
192
 
193
+ # Remove common signature lines at the end (best-effort).
194
+ body = re.sub(r'(?im)^\s*(yours.*|kind regards.*|regards.*)$', '', body).strip()
 
195
 
196
+ # If it doesn't contain at least one numbered point, it's still usable, but we mark as not strictly-valid.
197
+ is_valid = bool(re.search(r'(?m)^\s*\d+\.\s', body))
198
+ return body, is_valid
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ # --- Helper: wrap content in the FOI letter template ---
201
+ def wrap_in_letter(authority: str, body: str) -> str:
202
+ body = body.strip()
203
+ template = (
204
+ f"Dear {authority}\n\n"
205
+ "Please provide me with a copy of the following information:\n\n"
206
+ f"{body}\n\n"
207
+ "Yours faithfully,"
208
+ )
209
+ return template
210
 
211
  # --- Backend Function for Local Inference ---
212
  @spaces.GPU
 
219
  keyword_string = ", ".join(keywords)
220
  prompt = (
221
  "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
222
+ f"Generate ONLY the numbered list of the specific information being requested, starting at '1.' "
223
+ f"for {authority}, using these keywords: {keyword_string}. "
224
+ "Do not include greetings or signatures."
225
  )
226
 
227
  max_retries = 2
 
233
  # Set generation parameters
234
  generation_params = {
235
  "max_new_tokens": 340,
236
+ "temperature": 0.0,
237
  "top_p": 0.95,
238
  "top_k": 50,
239
  "repetition_penalty": 1.1,
 
254
  if generated_text.startswith('.\n'):
255
  generated_text = generated_text[2:]
256
 
257
+ # Clean and validate the output
258
  cleaned_text, is_valid = clean_and_validate_output(generated_text)
259
+
260
+ # Wrap in the letter template regardless; validation just influences retry behavior
261
+ letter = wrap_in_letter(authority, cleaned_text)
262
+
263
  if is_valid:
264
+ return letter
265
  else:
266
+ print(f"Attempt {attempt + 1}/{max_retries}: Output lacked clear numbering. Retrying...")
267
 
268
  except Exception as e:
269
  print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
270
  if attempt == max_retries - 1:
271
  return f"An error occurred during text generation: {e}"
272
 
273
+ # If retries failed, return the best effort letter using the last cleaned text we had
274
+ return wrap_in_letter(authority, "1. [Unable to format automatically] Please restate the information requested.\n2. [Optional second point]")
 
275
 
276
  # --- Gradio UI and Spinning Logic ---
277
  def spin_the_reels():
 
290
  "Spinning..."
291
  )
292
  time.sleep(spin_interval)
293
+
294
  # 2. Select the final fixed combination
295
  final_combination = random.choice(FOI_COMBINATIONS)
296
  final_authority = final_combination["authority"]
297
+
298
  # Split, strip, and pad keywords to ensure we always have 3 for the UI
299
  keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
300
  keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
301
  kw1, kw2, kw3 = keywords_list[:3] # Take the first 3
302
+
303
  # Display the final reel values and a "Generating..." message
304
  yield (
305
  final_authority, kw1, kw2, kw3,
306
  f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
307
  )
308
+
309
  # 3. Call the local model and yield the final result
310
  generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
311
  yield (
 
361
  reel2 = gr.Textbox(label="Keyword 1", interactive=False, elem_id="reel-2", scale=1)
362
  reel3 = gr.Textbox(label="Keyword 2", interactive=False, elem_id="reel-3", scale=1)
363
  reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
364
+
365
  pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
366
+
367
  output_request = gr.Textbox(
368
  label="Generated FOI Request",
369
  lines=15,
 
378
  )
379
 
380
  if __name__ == "__main__":
381
+ demo.launch()