HMC83 commited on
Commit
d5afe66
·
verified ·
1 Parent(s): 493381a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import spaces
8
 
9
- MODEL_ID = "HMC83/Wightgar-650M-RequestWriter-DPO"
10
 
11
  # --- Load Model and Tokenizer ---
12
  print("Loading model and tokenizer...")
@@ -107,9 +107,13 @@ def generate_request_local(authority, kw1, kw2, kw3):
107
 
108
  keywords = [kw for kw in [kw1, kw2, kw3] if kw]
109
  keyword_string = ", ".join(keywords)
 
 
 
110
  prompt = (
111
- "You are an expert at writing formal Freedom of Information requests to UK public authorities."
112
- f"""Generate a formal Freedom of Information request to {authority} using these keywords: {keyword_string}"""
 
113
  )
114
 
115
  try:
@@ -130,7 +134,7 @@ def generate_request_local(authority, kw1, kw2, kw3):
130
  # Generate text sequences
131
  output_sequences = model.generate(**inputs, **generation_params)
132
 
133
- # Decode the generated text
134
  generated_text = tokenizer.decode(
135
  output_sequences[0][len(inputs["input_ids"][0]):],
136
  skip_special_tokens=True
@@ -141,12 +145,10 @@ def generate_request_local(authority, kw1, kw2, kw3):
141
  generated_text = generated_text[2:]
142
 
143
  return generated_text
144
-
145
  except Exception as e:
146
  print(f"Error during generation: {e}")
147
  return f"An error occurred during text generation: {e}"
148
 
149
-
150
  # --- Gradio UI and Spinning Logic ---
151
  def spin_the_reels():
152
  """A generator function that simulates spinning reels and then calls the model."""
@@ -168,7 +170,6 @@ def spin_the_reels():
168
  # 2. Select the final fixed combination
169
  final_combination = random.choice(FOI_COMBINATIONS)
170
  final_authority = final_combination["authority"]
171
-
172
  # Split, strip, and pad keywords to ensure we always have 3 for the UI
173
  keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
174
  keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
@@ -237,7 +238,6 @@ with gr.Blocks(css=reels_css, theme=gr.themes.Soft()) as demo:
237
  reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
238
 
239
  pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
240
-
241
  output_request = gr.Textbox(
242
  label="Generated FOI Request",
243
  lines=15,
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import spaces
8
 
9
+ MODEL_ID = "HMC83/Wihtgar-650M-DPO-Requests-2"
10
 
11
  # --- Load Model and Tokenizer ---
12
  print("Loading model and tokenizer...")
 
107
 
108
  keywords = [kw for kw in [kw1, kw2, kw3] if kw]
109
  keyword_string = ", ".join(keywords)
110
+
111
+ instruction = f"Generate a formal Freedom of Information request to {authority} using these keywords: {keyword_string}"
112
+
113
  prompt = (
114
+ "You are an expert at writing formal Freedom of Information requests to UK public authorities. "
115
+ "Write clear, specific, professional requests that comply with FOI requirements and use accessible language. "
116
+ f"### Instruction:\n{instruction}\n\n### Response:\n"
117
  )
118
 
119
  try:
 
134
  # Generate text sequences
135
  output_sequences = model.generate(**inputs, **generation_params)
136
 
137
+ # Decode the generated text (this part correctly decodes only the new tokens)
138
  generated_text = tokenizer.decode(
139
  output_sequences[0][len(inputs["input_ids"][0]):],
140
  skip_special_tokens=True
 
145
  generated_text = generated_text[2:]
146
 
147
  return generated_text
 
148
  except Exception as e:
149
  print(f"Error during generation: {e}")
150
  return f"An error occurred during text generation: {e}"
151
 
 
152
  # --- Gradio UI and Spinning Logic ---
153
  def spin_the_reels():
154
  """A generator function that simulates spinning reels and then calls the model."""
 
170
  # 2. Select the final fixed combination
171
  final_combination = random.choice(FOI_COMBINATIONS)
172
  final_authority = final_combination["authority"]
 
173
  # Split, strip, and pad keywords to ensure we always have 3 for the UI
174
  keywords_list = [k.strip() for k in final_combination["keywords"].split(',')]
175
  keywords_list += [''] * (3 - len(keywords_list)) # Pad with empty strings if < 3
 
238
  reel4 = gr.Textbox(label="Keyword 3", interactive=False, elem_id="reel-4", scale=1)
239
 
240
  pull_button = gr.Button("Generate a request", variant="primary", elem_id="pull-button")
 
241
  output_request = gr.Textbox(
242
  label="Generated FOI Request",
243
  lines=15,