RayBe commited on
Commit
ded2ee6
·
verified ·
1 Parent(s): aa20016

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -28
app.py CHANGED
@@ -14,44 +14,34 @@ device = torch.device("cpu")
14
  model.to(device)
15
 
16
  def extract_amount(input_text):
17
- """
18
- Extracts the first number (with optional decimals) from the input text.
19
- For example, in:
20
- "Should I send 2659.53464 EUR to my wife today?"
21
- it returns the string "2659.53464".
22
- """
23
  match = re.search(r'\b(\d+(?:\.\d+)?)\b', input_text)
24
- if match:
25
- return match.group(1)
26
- return None
27
 
28
  def merge_json_with_amount(model_output, amount):
29
- """
30
- Parses the model's JSON output and overrides the "amount" key
31
- with the manually extracted value.
32
- """
33
  try:
34
- data = json.loads(model_output)
35
- except Exception:
36
- data = {}
 
37
  if amount:
38
  try:
39
- data["amount"] = float(amount)
40
- except Exception:
41
- data["amount"] = amount
 
42
  return json.dumps(data, ensure_ascii=False)
43
 
44
  def generate_command(input_command):
45
- # Manually extract the amount from the input.
46
- amount = extract_amount(input_command)
 
 
47
 
48
- # Generate the JSON output from the model.
49
- prompt = "extract: " + input_command
50
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
51
- output_ids = model.generate(input_ids, max_length=128, num_beams=2, early_stopping=True)
52
  model_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
53
-
54
- # Merge the manually extracted amount into the model output.
55
  result = merge_json_with_amount(model_output, amount)
56
  return result
57
 
@@ -60,7 +50,7 @@ iface = gr.Interface(
60
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
61
  outputs=gr.Textbox(label="Extracted JSON Output"),
62
  title="T5 Command Extractor",
63
- description="The model provides action, currency, and recipient. The amount is manually extracted from the input."
64
  )
65
 
66
  if __name__ == "__main__":
 
14
  model.to(device)
15
 
16
  def extract_amount(input_text):
17
+ """Extracts the first numeric value (with decimals) from the input text."""
 
 
 
 
 
18
  match = re.search(r'\b(\d+(?:\.\d+)?)\b', input_text)
19
+ return match.group(1) if match else None
 
 
20
 
21
  def merge_json_with_amount(model_output, amount):
22
+ """Ensures the extracted amount is inserted into the model's output JSON."""
 
 
 
23
  try:
24
+ data = json.loads(model_output) # Try parsing the model output
25
+ except json.JSONDecodeError:
26
+ return model_output # Return raw model output if it's not valid JSON
27
+
28
  if amount:
29
  try:
30
+ data["amount"] = float(amount) # Convert to float
31
+ except ValueError:
32
+ data["amount"] = amount # Keep as string if conversion fails
33
+
34
  return json.dumps(data, ensure_ascii=False)
35
 
36
  def generate_command(input_command):
37
+ """Processes the input, extracts the amount manually, and runs the model."""
38
+ amount = extract_amount(input_command) # Extract amount
39
+
40
+ input_ids = tokenizer("extract: " + input_command, return_tensors="pt").input_ids.to(device)
41
 
42
+ output_ids = model.generate(input_ids, max_length=128, num_beams=1, early_stopping=True) # Faster decoding
 
 
 
43
  model_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
44
+
 
45
  result = merge_json_with_amount(model_output, amount)
46
  return result
47
 
 
50
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
51
  outputs=gr.Textbox(label="Extracted JSON Output"),
52
  title="T5 Command Extractor",
53
+ description="The model extracts action, currency, and recipient. The amount is manually extracted."
54
  )
55
 
56
  if __name__ == "__main__":