RayBe commited on
Commit
aa20016
·
verified ·
1 Parent(s): db6252e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -65
app.py CHANGED
@@ -9,102 +9,58 @@ model_name = "./t5-finetuned-final"
9
  tokenizer = T5Tokenizer.from_pretrained(model_name)
10
  model = T5ForConditionalGeneration.from_pretrained(model_name)
11
 
12
- # Move model to CPU (explicitly)
13
  device = torch.device("cpu")
14
  model.to(device)
15
 
16
  def extract_amount(input_text):
17
  """
18
- Extracts the amount from the input text using a robust regex.
19
- The negative lookahead (?!\S) ensures we stop capturing as soon as a non-space character appears.
 
 
20
  """
21
- # First try: match when the amount follows keywords like send, loan, pay, or transfer.
22
- amount_match = re.search(
23
- r'(?:send|loan|pay|transfer)\s*(\d+(?:\.\d+)?)(?!\S)',
24
- input_text,
25
- re.IGNORECASE
26
- )
27
- if not amount_match:
28
- # Fallback: match a number that is immediately followed by a currency symbol/abbreviation.
29
- amount_match = re.search(
30
- r'\b(\d+(?:\.\d+)?)\s*(?:AUD|USD|USDT|ETH|BTC|EUR)\b',
31
- input_text,
32
- re.IGNORECASE
33
- )
34
- if amount_match:
35
- return amount_match.group(1).strip()
36
  return None
37
 
38
- def fix_json_output(output):
39
- """
40
- Fixes common JSON formatting issues in the model's output .
41
- """
42
- # Remove trailing commas before closing braces/brackets
43
- output = re.sub(r',\s*([}\]])', r'\1', output)
44
- # Fix missing or extra quotes around keys
45
- output = re.sub(r'([{,])\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*:', r'\1"\2":', output)
46
- # Fix missing or extra quotes around string values
47
- output = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*([,}])', r':"\1"\2', output)
48
- return output
49
-
50
  def merge_json_with_amount(model_output, amount):
51
  """
52
- Updates only the 'amount' field in the model's JSON output,
53
- leaving all other fields as produced by the model.
54
  """
55
  try:
56
  data = json.loads(model_output)
57
- except json.JSONDecodeError:
58
- # If JSON parsing fails, attempt to fix common formatting issues.
59
- fixed_output = fix_json_output(model_output)
60
- try:
61
- data = json.loads(fixed_output)
62
- except json.JSONDecodeError:
63
- # If it still fails, return the model output unmodified.
64
- return model_output
65
-
66
  if amount:
67
  try:
68
- # Convert the cleaned string to a float
69
- data["amount"] = float(amount.strip())
70
- except ValueError:
71
- # In case conversion fails, keep the original string.
72
  data["amount"] = amount
73
-
74
  return json.dumps(data, ensure_ascii=False)
75
 
76
  def generate_command(input_command):
77
- # Extract the amount from the input
78
  amount = extract_amount(input_command)
79
 
80
- # Generate the JSON output using the model
81
  prompt = "extract: " + input_command
82
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
83
-
84
- output_ids = model.generate(
85
- input_ids,
86
- max_length=128, # Increased max_length to allow for complete JSON output
87
- num_beams=2, # Using beam search for better output quality
88
- early_stopping=True
89
- )
90
-
91
  model_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
92
 
93
- # Merge the model's output with the extracted amount.
94
- # The merge function only replaces the "amount" field, leaving all other keys intact.
95
- if amount:
96
- result = merge_json_with_amount(model_output, amount)
97
- else:
98
- result = model_output # Use the model's output as-is if no amount is found
99
-
100
  return result
101
 
102
  iface = gr.Interface(
103
  fn=generate_command,
104
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
105
  outputs=gr.Textbox(label="Extracted JSON Output"),
106
- title="T5 Fine-Tuned Command Extractor",
107
- description="Extracts details in JSON format and replaces the amount with the exact value from the input.",
108
  )
109
 
110
  if __name__ == "__main__":
 
9
  tokenizer = T5Tokenizer.from_pretrained(model_name)
10
  model = T5ForConditionalGeneration.from_pretrained(model_name)
11
 
12
+ # Move model to CPU explicitly
13
  device = torch.device("cpu")
14
  model.to(device)
15
 
16
  def extract_amount(input_text):
17
  """
18
+ Extracts the first number (with optional decimals) from the input text.
19
+ For example, in:
20
+ "Should I send 2659.53464 EUR to my wife today?"
21
+ it returns the string "2659.53464".
22
  """
23
+ match = re.search(r'\b(\d+(?:\.\d+)?)\b', input_text)
24
+ if match:
25
+ return match.group(1)
 
 
 
 
 
 
 
 
 
 
 
 
26
  return None
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def merge_json_with_amount(model_output, amount):
29
  """
30
+ Parses the model's JSON output and overrides the "amount" key
31
+ with the manually extracted value.
32
  """
33
  try:
34
  data = json.loads(model_output)
35
+ except Exception:
36
+ data = {}
 
 
 
 
 
 
 
37
  if amount:
38
  try:
39
+ data["amount"] = float(amount)
40
+ except Exception:
 
 
41
  data["amount"] = amount
 
42
  return json.dumps(data, ensure_ascii=False)
43
 
44
  def generate_command(input_command):
45
+ # Manually extract the amount from the input.
46
  amount = extract_amount(input_command)
47
 
48
+ # Generate the JSON output from the model.
49
  prompt = "extract: " + input_command
50
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
51
+ output_ids = model.generate(input_ids, max_length=128, num_beams=2, early_stopping=True)
 
 
 
 
 
 
 
52
  model_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
53
 
54
+ # Merge the manually extracted amount into the model output.
55
+ result = merge_json_with_amount(model_output, amount)
 
 
 
 
 
56
  return result
57
 
58
  iface = gr.Interface(
59
  fn=generate_command,
60
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
61
  outputs=gr.Textbox(label="Extracted JSON Output"),
62
+ title="T5 Command Extractor",
63
+ description="The model provides action, currency, and recipient. The amount is manually extracted from the input."
64
  )
65
 
66
  if __name__ == "__main__":