HMC83's picture
Update app.py
4ed1f95 verified
raw
history blame
10.7 kB
import gradio as gr
import os
import random
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
MODEL_ID = "HMC83/chaurus_360M_requests"

# --- Load Model and Tokenizer ---
# The model and tokenizer are loaded once, at import time, and shared by every
# request. If loading fails for any reason, both globals are set to None so the
# generation function can report the problem instead of the app crashing.
print("Loading model and tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = tokenizer = None
else:
    print("Model loaded successfully.")
# --- Data for the Reels ---
# Curated authority/keyword combinations used for the final reel values. Each
# entry holds an authority name and a comma-separated keyword string.
# Keep every entry unique: random.choice() picks uniformly over this list, so a
# duplicated entry would be drawn twice as often as the others.
FOI_COMBINATIONS = [
    {"authority": "Isle of Wight Council", "keywords": "Yeti, habitat, evidence"},
    {"authority": "The Royal Parks", "keywords": "squirrel uprising, acorn stockpiles, intelligence"},
    {"authority": "Scottish Government", "keywords": "the Loch Ness Monster, research, funding"},
    {"authority": "Westminster City Council", "keywords": "noise complaints, nightlife, statistics"},
    {"authority": "Leeds Teaching Hospitals NHS Trust", "keywords": "agency staff, temporary costs, recruitment spend"},
    {"authority": "Ministry of Justice", "keywords": "court delays, case backlogs, administrative costs"},
    {"authority": "Forestry Commission", "keywords": "timber sales, revenue targets, environmental impact"},
    {"authority": "Ofcom", "keywords": "broadcasting complaints, investigation procedures, penalty decisions"},
    {"authority": "Brighton and Hove Council", "keywords": "seafront, event licensing, revenue generation"},
    {"authority": "Chester West Council", "keywords": "tourism, marketing spend, economic impact"},
    {"authority": "Durham County Council", "keywords": "mining subsidence, compensation claims, repair costs"},
    {"authority": "Warwickshire Police", "keywords": "police stations, closure plans, public consultations"},
    {"authority": "West Mercia Police", "keywords": "cybercrime, investigation resources, training costs"},
    {"authority": "Environment Agency", "keywords": "flood defenses, maintenance spending, effectiveness assessments"},
    {"authority": "Historic England", "keywords": "listed buildings, consent applications, enforcement cases"},
    {"authority": "Highways England", "keywords": "motorway maintenance, contractor payments, performance indicators"},
    {"authority": "Network Rail", "keywords": "signal failures, delay minutes, compensation costs"},
    {"authority": "Civil Aviation Authority", "keywords": "airline complaints, resolution times, enforcement action"},
    {"authority": "Maritime and Coastguard Agency", "keywords": "rescue operations, helicopter costs, response times"},
    {"authority": "Planning Inspectorate", "keywords": "appeal decisions, processing times, outcome statistics"},
    {"authority": "Ofgem", "keywords": "energy company, penalties imposed, consumer complaints"},
    {"authority": "Ofwat", "keywords": "water companies, price reviews, performance monitoring"},
    {"authority": "Financial Conduct Authority", "keywords": "financial services, enforcement cases, penalty amounts"},
    {"authority": "Competition and Markets Authority", "keywords": "merger investigations, market studies, intervention costs"},
    {"authority": "Gambling Commission", "keywords": "license suspensions, operator penalties, compliance monitoring"},
    {"authority": "Information Commissioner's Office", "keywords": "data breaches, penalty notices, investigation outcomes"},
    {"authority": "Electoral Commission", "keywords": "campaign spending, compliance checks, penalty procedures"},
    {"authority": "HM Revenue and Customs", "keywords": "tax investigations, recovery amounts, compliance costs"},
    {"authority": "HM Treasury", "keywords": "departmental budgets, spending reviews, efficiency targets"},
    {"authority": "Cabinet Office", "keywords": "government consultancy, external advisors, procurement costs"},
    {"authority": "Home Office", "keywords": "immigration appeals, processing times, detention costs"},
    {"authority": "Department for Transport", "keywords": "railway subsidies, franchise performance, punctuality targets"},
    {"authority": "Department for Education", "keywords": "academy conversions, funding allocations, performance data"},
    {"authority": "Department for Work and Pensions", "keywords": "benefit sanctions, appeal success, administrative costs"},
    {"authority": "Department of Health", "keywords": "NHS funding, allocation formulas, performance targets"},
    {"authority": "Department for Environment", "keywords": "environmental fines, pollution incidents, enforcement costs"},
]

# Pools of values shown while the reels spin, derived from the combinations
# above. Set comprehensions drop repeats (several keywords, e.g.
# "administrative costs" and "processing times", occur in more than one entry);
# sorted() gives a stable order, so runs are reproducible under random.seed().
ALL_AUTHORITIES_FOR_SPIN = sorted({item["authority"] for item in FOI_COMBINATIONS})
ALL_KEYWORDS_FOR_SPIN = sorted(
    {kw.strip() for item in FOI_COMBINATIONS for kw in item["keywords"].split(",")}
)
# --- Backend Function for Local Inference ---
@spaces.GPU
def generate_request_local(authority, kw1, kw2, kw3):
    """Generate an FOI request for *authority* using the locally loaded model.

    Args:
        authority: Name of the public authority the request is addressed to.
        kw1, kw2, kw3: Keywords to steer the request. Empty strings are skipped.

    Returns:
        The generated request text. If the model or tokenizer failed to load,
        or if generation raises, a human-readable error message is returned
        instead. Errors are returned rather than raised so that the Gradio
        output box always has something to display.
    """
    # Explicit None checks: the loader sets both globals to None on failure.
    # Relying on object truthiness is fragile; for example, tokenizers define
    # __len__.
    if model is None or tokenizer is None:
        return "Error: Model is not loaded. Please check the Space logs for details."

    keywords = [kw for kw in (kw1, kw2, kw3) if kw]
    keyword_string = ", ".join(keywords)
    instruction = f"Generate a formal Freedom of Information request to {authority} using these keywords: {keyword_string}"
    # NOTE(review): this prompt layout (system text, then "### Instruction" /
    # "### Response" markers) presumably mirrors the fine-tuning template.
    # Keep it byte-identical unless that is confirmed.
    prompt = (
        "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
        "Write clear, specific, professional requests that comply with FOI requirements and use accessible language. "
        f"### Instruction:\n{instruction}\n\n### Response:\n"
    )
    try:
        # Tokenize the prompt and move the tensors to the model's device.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Drop token_type_ids if the tokenizer emits them. Presumably this is
        # because the causal-LM forward() does not accept them; confirm.
        inputs.pop("token_type_ids", None)

        # Low temperature keeps sampling close to greedy decoding while still
        # allowing a little variation between requests.
        pad_id = tokenizer.pad_token_id
        generation_params = {
            "max_new_tokens": 250,
            "temperature": 0.13,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.15,
            "do_sample": True,
            "eos_token_id": tokenizer.eos_token_id,
            # Passing pad_token_id explicitly stops generate() from warning
            # and falling back to eos_token_id on every call.
            "pad_token_id": pad_id if pad_id is not None else tokenizer.eos_token_id,
        }
        output_sequences = model.generate(**inputs, **generation_params)

        # generate() returns prompt + continuation, so decode only the tokens
        # that come after the prompt.
        prompt_length = len(inputs["input_ids"][0])
        generated_text = tokenizer.decode(
            output_sequences[0][prompt_length:],
            skip_special_tokens=True,
        ).strip()

        # Remove the leading ".\n" artifact if present, along with any
        # whitespace that follows it.
        if generated_text.startswith(".\n"):
            generated_text = generated_text[2:].lstrip()
        return generated_text
    except Exception as e:
        print(f"Error during generation: {e}")
        return f"An error occurred during text generation: {e}"
# --- Gradio UI and Spinning Logic ---
def _final_reel_values():
    """Pick one curated combination and return (authority, kw1, kw2, kw3)."""
    combination = random.choice(FOI_COMBINATIONS)
    keywords = [k.strip() for k in combination["keywords"].split(",")]
    # Pad with empty strings, or truncate, so there are exactly three keyword reels.
    kw1, kw2, kw3 = (keywords + ["", "", ""])[:3]
    return combination["authority"], kw1, kw2, kw3


def spin_the_reels():
    """Animate the reels, settle on a combination, then generate a request.

    This is a generator used as a Gradio event handler. Each yielded tuple maps
    to (authority reel, keyword reel 1, keyword reel 2, keyword reel 3,
    output text box):

    1. For about two seconds, random authorities and keywords are yielded
       together with the status "Spinning...".
    2. A fixed combination is chosen from FOI_COMBINATIONS and shown with a
       "Generating..." message.
    3. The model output (or its error message) is yielded as the final result.
    """
    spin_duration = 2.0  # seconds
    spin_interval = 0.05  # seconds between reel updates

    # time.monotonic() is unaffected by system clock adjustments (e.g. NTP),
    # which could otherwise shorten or stretch the animation.
    deadline = time.monotonic() + spin_duration
    while time.monotonic() < deadline:
        yield (
            random.choice(ALL_AUTHORITIES_FOR_SPIN),
            random.choice(ALL_KEYWORDS_FOR_SPIN),
            random.choice(ALL_KEYWORDS_FOR_SPIN),
            random.choice(ALL_KEYWORDS_FOR_SPIN),
            "Spinning...",
        )
        time.sleep(spin_interval)

    final_authority, kw1, kw2, kw3 = _final_reel_values()
    yield (
        final_authority, kw1, kw2, kw3,
        f"Generating request for {final_authority}...\nPlease wait, this may take a moment.",
    )

    # Blocking call; the reels keep showing the final values meanwhile.
    generated_request = generate_request_local(final_authority, kw1, kw2, kw3)
    yield (final_authority, kw1, kw2, kw3, generated_request)
# --- CSS for Styling ---
# Custom stylesheet passed to gr.Blocks(css=...).
# - "#reels-container" targets the gr.Row created with that elem_id.
# - "#pull-button" targets the gr.Button created with that elem_id.
# NOTE(review): the ".gradio-textbox" selectors assume Gradio renders Textbox
# wrappers with that class name; verify against the installed Gradio version.
# Added min-width to reduce UI flickering on text change
reels_css = """
#reels-container {
display: flex;
gap: 1rem;
margin-bottom: 1rem;
}
#reels-container .gradio-textbox {
text-align: center;
border: 4px solid #f59e0b;
border-radius: 12px;
background-color: #fef3c7;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
min-width: 150px; /* Prevents resizing/flickering during spin */
}
#reels-container .gradio-textbox input {
font-size: 1.25rem !important;
font-weight: bold;
color: #b45309;
text-align: center;
}
#pull-button {
font-size: 1.5rem !important;
font-weight: bold;
padding: 1rem !important;
box-shadow: 0 8px 15px rgba(0,0,0,0.2);
transition: all 0.2s ease-in-out;
}
#pull-button:hover {
transform: translateY(-2px);
box-shadow: 0 10px 20px rgba(0,0,0,0.3);
}
"""
# --- Build the Gradio App ---
# Layout, top to bottom:
# - an intro Markdown header;
# - a row of four read-only "reels" (authority + three keywords);
# - the button that starts the spin;
# - an editable, copyable text box for the generated request.
# The button streams spin_the_reels' yielded tuples into the four reels and the
# output box, in that order.
with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "\n"
        "# FOI Request-O-Matic 🎰\n"
        "Press the button to generate a Freedom of Information request using chaurus_360M.\n"
    )

    with gr.Row(elem_id="reels-container"):
        # (label, elem_id, scale). The authority reel gets twice the width.
        reel1, reel2, reel3, reel4 = [
            gr.Textbox(label=label, interactive=False, elem_id=elem_id, scale=scale)
            for label, elem_id, scale in (
                ("Authority", "reel-1", 2),
                ("Keyword 1", "reel-2", 1),
                ("Keyword 2", "reel-3", 1),
                ("Keyword 3", "reel-4", 1),
            )
        ]

    pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
    output_request = gr.Textbox(
        label="Generated FOI Request",
        lines=15,
        interactive=True,
        show_copy_button=True,
    )

    reel_outputs = [reel1, reel2, reel3, reel4, output_request]
    pull_button.click(fn=spin_the_reels, inputs=[], outputs=reel_outputs)

if __name__ == "__main__":
    demo.launch()