Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,10 +3,11 @@ import os
|
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
import torch
|
|
|
|
| 6 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 7 |
import spaces
|
| 8 |
|
| 9 |
-
MODEL_ID = "HMC83/Wihtgar-650M-
|
| 10 |
|
| 11 |
# --- Load Model and Tokenizer ---
|
| 12 |
print("Loading model and tokenizer...")
|
|
@@ -179,46 +180,33 @@ FOI_COMBINATIONS = [
|
|
| 179 |
ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATIONS]))
|
| 180 |
ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def clean_and_validate_output(raw_text: str) -> tuple[str, bool]:
|
| 185 |
"""
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
It validates that the output contains essential markers ("Dear" and "[Your Name]").
|
| 189 |
-
If it detects that the model has started generating a second request, it truncates
|
| 190 |
-
the string after the first "[Your Name]".
|
| 191 |
-
|
| 192 |
-
Args:
|
| 193 |
-
raw_text: The raw string output from the language model.
|
| 194 |
-
|
| 195 |
-
Returns:
|
| 196 |
-
A tuple containing:
|
| 197 |
-
- The cleaned text.
|
| 198 |
-
- A boolean flag: True if the output is valid, False if it is malformed.
|
| 199 |
"""
|
| 200 |
-
|
| 201 |
-
|
|
|
|
| 202 |
|
| 203 |
-
#
|
| 204 |
-
|
| 205 |
-
return raw_text, False # Malformed, signal for regeneration.
|
| 206 |
|
| 207 |
-
#
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
# Check if a second request has started after the first one ended.
|
| 212 |
-
start_of_second_request_pos = raw_text.find(start_marker, end_of_first_request_index)
|
| 213 |
-
|
| 214 |
-
if start_of_second_request_pos != -1:
|
| 215 |
-
# If a second request is found, truncate to keep only the first one.
|
| 216 |
-
cleaned_text = raw_text[:end_of_first_request_index]
|
| 217 |
-
return cleaned_text, True
|
| 218 |
-
else:
|
| 219 |
-
# No second request found, the output is valid.
|
| 220 |
-
return raw_text, True
|
| 221 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
# --- Backend Function for Local Inference ---
|
| 224 |
@spaces.GPU
|
|
@@ -231,7 +219,9 @@ def generate_request_local(authority, kw1, kw2, kw3):
|
|
| 231 |
keyword_string = ", ".join(keywords)
|
| 232 |
prompt = (
|
| 233 |
"You are an expert at writing formal Freedom of Information requests to UK public authorities. "
|
| 234 |
-
f"
|
|
|
|
|
|
|
| 235 |
)
|
| 236 |
|
| 237 |
max_retries = 2
|
|
@@ -243,7 +233,7 @@ def generate_request_local(authority, kw1, kw2, kw3):
|
|
| 243 |
# Set generation parameters
|
| 244 |
generation_params = {
|
| 245 |
"max_new_tokens": 340,
|
| 246 |
-
"temperature": 0.
|
| 247 |
"top_p": 0.95,
|
| 248 |
"top_k": 50,
|
| 249 |
"repetition_penalty": 1.1,
|
|
@@ -264,22 +254,24 @@ def generate_request_local(authority, kw1, kw2, kw3):
|
|
| 264 |
if generated_text.startswith('.\n'):
|
| 265 |
generated_text = generated_text[2:]
|
| 266 |
|
| 267 |
-
#
|
| 268 |
cleaned_text, is_valid = clean_and_validate_output(generated_text)
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
| 270 |
if is_valid:
|
| 271 |
-
return
|
| 272 |
else:
|
| 273 |
-
print(f"Attempt {attempt + 1}/{max_retries}:
|
| 274 |
|
| 275 |
except Exception as e:
|
| 276 |
print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
|
| 277 |
if attempt == max_retries - 1:
|
| 278 |
return f"An error occurred during text generation: {e}"
|
| 279 |
|
| 280 |
-
# If the
|
| 281 |
-
return "
|
| 282 |
-
|
| 283 |
|
| 284 |
# --- Gradio UI and Spinning Logic ---
|
| 285 |
def spin_the_reels():
|
|
@@ -298,22 +290,22 @@ def spin_the_reels():
|
|
| 298 |
"Spinning..."
|
| 299 |
)
|
| 300 |
time.sleep(spin_interval)
|
| 301 |
-
|
| 302 |
# 2. Select the final fixed combination
|
| 303 |
final_combination = random.choice(FOI_COMBINATIONS)
|
| 304 |
final_authority = final_combination["authority"]
|
| 305 |
-
|
| 306 |
# Split, strip, and pad keywords to ensure we always have 3 for the UI
|
| 307 |
keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
|
| 308 |
keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
|
| 309 |
kw1, kw2, kw3 = keywords_list[:3] # Take the first 3
|
| 310 |
-
|
| 311 |
# Display the final reel values and a "Generating..." message
|
| 312 |
yield (
|
| 313 |
final_authority, kw1, kw2, kw3,
|
| 314 |
f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
|
| 315 |
)
|
| 316 |
-
|
| 317 |
# 3. Call the local model and yield the final result
|
| 318 |
generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
|
| 319 |
yield (
|
|
@@ -369,9 +361,9 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
|
|
| 369 |
reel2 = gr.Textbox(label="Keyword 1", interactive=False, elem_id="reel-2", scale=1)
|
| 370 |
reel3 = gr.Textbox(label="Keyword 2", interactive=False, elem_id="reel-3", scale=1)
|
| 371 |
reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
|
| 372 |
-
|
| 373 |
pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
|
| 374 |
-
|
| 375 |
output_request = gr.Textbox(
|
| 376 |
label="Generated FOI Request",
|
| 377 |
lines=15,
|
|
@@ -386,4 +378,4 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
|
|
| 386 |
)
|
| 387 |
|
| 388 |
if __name__ == "__main__":
|
| 389 |
-
demo.launch()
|
|
|
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
import torch
|
| 6 |
+
import re # <-- NEW
|
| 7 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
import spaces
|
| 9 |
|
| 10 |
+
MODEL_ID = "HMC83/Wihtgar-650M-SFT-Requests_2-Merged"
|
| 11 |
|
| 12 |
# --- Load Model and Tokenizer ---
|
| 13 |
print("Loading model and tokenizer...")
|
|
|
|
| 180 |
ALL_AUTHORITIES_FOR_SPIN = list(set([item["authority"] for item in FOI_COMBINATIONS]))
|
| 181 |
ALL_KEYWORDS_FOR_SPIN = list(set(kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(',')))
|
| 182 |
|
| 183 |
+
# --- Helper: clean model output into a numbered list starting at "1." ---
|
| 184 |
+
def clean_and_validate_output(text: str):
|
|
|
|
| 185 |
"""
|
| 186 |
+
Extract the main numbered list starting at '1.' and strip any closing signature lines.
|
| 187 |
+
Always returns cleaned text and a boolean flag (True = looks fine).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
"""
|
| 189 |
+
# Keep everything from the first "1." onward, if present.
|
| 190 |
+
m = re.search(r'(?m)^\s*1\.\s', text)
|
| 191 |
+
body = text[m.start():].strip() if m else text.strip()
|
| 192 |
|
| 193 |
+
# Remove common signature lines at the end (best-effort).
|
| 194 |
+
body = re.sub(r'(?im)^\s*(yours.*|kind regards.*|regards.*)$', '', body).strip()
|
|
|
|
| 195 |
|
| 196 |
+
# If it doesn't contain at least one numbered point, it's still usable, but we mark as not strictly-valid.
|
| 197 |
+
is_valid = bool(re.search(r'(?m)^\s*\d+\.\s', body))
|
| 198 |
+
return body, is_valid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
+
# --- Helper: wrap content in the FOI letter template ---
|
| 201 |
+
def wrap_in_letter(authority: str, body: str) -> str:
|
| 202 |
+
body = body.strip()
|
| 203 |
+
template = (
|
| 204 |
+
f"Dear {authority}\n\n"
|
| 205 |
+
"Please provide me with a copy of the following information:\n\n"
|
| 206 |
+
f"{body}\n\n"
|
| 207 |
+
"Yours faithfully,"
|
| 208 |
+
)
|
| 209 |
+
return template
|
| 210 |
|
| 211 |
# --- Backend Function for Local Inference ---
|
| 212 |
@spaces.GPU
|
|
|
|
| 219 |
keyword_string = ", ".join(keywords)
|
| 220 |
prompt = (
|
| 221 |
"You are an expert at writing formal Freedom of Information requests to UK public authorities. "
|
| 222 |
+
f"Generate ONLY the numbered list of the specific information being requested, starting at '1.' "
|
| 223 |
+
f"for {authority}, using these keywords: {keyword_string}. "
|
| 224 |
+
"Do not include greetings or signatures."
|
| 225 |
)
|
| 226 |
|
| 227 |
max_retries = 2
|
|
|
|
| 233 |
# Set generation parameters
|
| 234 |
generation_params = {
|
| 235 |
"max_new_tokens": 340,
|
| 236 |
+
"temperature": 0.0,
|
| 237 |
"top_p": 0.95,
|
| 238 |
"top_k": 50,
|
| 239 |
"repetition_penalty": 1.1,
|
|
|
|
| 254 |
if generated_text.startswith('.\n'):
|
| 255 |
generated_text = generated_text[2:]
|
| 256 |
|
| 257 |
+
# Clean and validate the output
|
| 258 |
cleaned_text, is_valid = clean_and_validate_output(generated_text)
|
| 259 |
+
|
| 260 |
+
# Wrap in the letter template regardless; validation just influences retry behavior
|
| 261 |
+
letter = wrap_in_letter(authority, cleaned_text)
|
| 262 |
+
|
| 263 |
if is_valid:
|
| 264 |
+
return letter
|
| 265 |
else:
|
| 266 |
+
print(f"Attempt {attempt + 1}/{max_retries}: Output lacked clear numbering. Retrying...")
|
| 267 |
|
| 268 |
except Exception as e:
|
| 269 |
print(f"Error during generation attempt {attempt + 1}/{max_retries}: {e}")
|
| 270 |
if attempt == max_retries - 1:
|
| 271 |
return f"An error occurred during text generation: {e}"
|
| 272 |
|
| 273 |
+
# If retries failed, return the best effort letter using the last cleaned text we had
|
| 274 |
+
return wrap_in_letter(authority, "1. [Unable to format automatically] Please restate the information requested.\n2. [Optional second point]")
|
|
|
|
| 275 |
|
| 276 |
# --- Gradio UI and Spinning Logic ---
|
| 277 |
def spin_the_reels():
|
|
|
|
| 290 |
"Spinning..."
|
| 291 |
)
|
| 292 |
time.sleep(spin_interval)
|
| 293 |
+
|
| 294 |
# 2. Select the final fixed combination
|
| 295 |
final_combination = random.choice(FOI_COMBINATIONS)
|
| 296 |
final_authority = final_combination["authority"]
|
| 297 |
+
|
| 298 |
# Split, strip, and pad keywords to ensure we always have 3 for the UI
|
| 299 |
keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
|
| 300 |
keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
|
| 301 |
kw1, kw2, kw3 = keywords_list[:3] # Take the first 3
|
| 302 |
+
|
| 303 |
# Display the final reel values and a "Generating..." message
|
| 304 |
yield (
|
| 305 |
final_authority, kw1, kw2, kw3,
|
| 306 |
f"Generating request for {final_authority}...\nPlease wait, this may take a moment."
|
| 307 |
)
|
| 308 |
+
|
| 309 |
# 3. Call the local model and yield the final result
|
| 310 |
generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
|
| 311 |
yield (
|
|
|
|
| 361 |
reel2 = gr.Textbox(label="Keyword 1", interactive=False, elem_id="reel-2", scale=1)
|
| 362 |
reel3 = gr.Textbox(label="Keyword 2", interactive=False, elem_id="reel-3", scale=1)
|
| 363 |
reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
|
| 364 |
+
|
| 365 |
pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
|
| 366 |
+
|
| 367 |
output_request = gr.Textbox(
|
| 368 |
label="Generated FOI Request",
|
| 369 |
lines=15,
|
|
|
|
| 378 |
)
|
| 379 |
|
| 380 |
if __name__ == "__main__":
|
| 381 |
+
demo.launch()
|