RayBe committed on
Commit
56991a9
·
verified ·
1 Parent(s): 727d394

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -18
app.py CHANGED
@@ -16,13 +16,23 @@ model.to(device)
16
  def extract_amount(input_text):
17
  """
18
  Extracts the amount from the input text using a robust regex.
 
19
  """
20
- # Improved regex to match amounts preceded by keywords or currency symbols
21
- amount_match = re.search(r'(?:send|loan|pay|transfer)\s*(\d+(?:\.\d+)?)', input_text, re.IGNORECASE)
 
 
 
 
22
  if not amount_match:
23
- amount_match = re.search(r'\b(\d+(?:\.\d+)?)\s*(?:AUD|USD|USDT|ETH|BTC)\b', input_text, re.IGNORECASE)
 
 
 
 
 
24
  if amount_match:
25
- return amount_match.group(1)
26
  return None
27
 
28
  def fix_json_output(output):
@@ -43,21 +53,26 @@ def merge_json_with_amount(model_output, amount):
43
  leaving all other fields as produced by the model.
44
  """
45
  try:
46
- # Attempt to load the model output directly as JSON.
47
  data = json.loads(model_output)
48
  except json.JSONDecodeError:
49
- # If parsing fails, just return the model output unmodified.
50
- # (You might choose to log an error here.)
51
- return model_output
 
 
 
 
52
 
53
- # Replace (or add) the "amount" field using the extracted amount.
54
  if amount:
55
- data["amount"] = float(amount) if '.' in amount else int(amount)
 
 
 
 
 
56
 
57
- # Dump back to JSON without altering other keys.
58
  return json.dumps(data, ensure_ascii=False)
59
 
60
-
61
  def generate_command(input_command):
62
  # Extract the amount from the input
63
  amount = extract_amount(input_command)
@@ -66,18 +81,17 @@ def generate_command(input_command):
66
  prompt = "extract: " + input_command
67
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
68
 
69
- # Generate output with increased max_length for complete JSON
70
  output_ids = model.generate(
71
  input_ids,
72
- max_length=128, # Increased to allow complete JSON output
73
- num_beams=2, # Reduced for faster inference
74
  early_stopping=True
75
  )
76
 
77
- # Decode the model's output
78
  model_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
79
 
80
- # Merge the model's output with the extracted amount
 
81
  if amount:
82
  result = merge_json_with_amount(model_output, amount)
83
  else:
@@ -94,4 +108,4 @@ iface = gr.Interface(
94
  )
95
 
96
  if __name__ == "__main__":
97
- iface.launch()
 
16
  def extract_amount(input_text):
17
  """
18
  Extracts the amount from the input text using a robust regex.
19
+ The negative lookahead (?!\S) ensures we stop capturing as soon as a non-space character appears.
20
  """
21
+ # First try: match when the amount follows keywords like send, loan, pay, or transfer.
22
+ amount_match = re.search(
23
+ r'(?:send|loan|pay|transfer)\s*(\d+(?:\.\d+)?)(?!\S)',
24
+ input_text,
25
+ re.IGNORECASE
26
+ )
27
  if not amount_match:
28
+ # Fallback: match a number that is immediately followed by a currency symbol/abbreviation.
29
+ amount_match = re.search(
30
+ r'\b(\d+(?:\.\d+)?)\s*(?:AUD|USD|USDT|ETH|BTC|EUR)\b',
31
+ input_text,
32
+ re.IGNORECASE
33
+ )
34
  if amount_match:
35
+ return amount_match.group(1).strip()
36
  return None
37
 
38
  def fix_json_output(output):
 
53
  leaving all other fields as produced by the model.
54
  """
55
  try:
 
56
  data = json.loads(model_output)
57
  except json.JSONDecodeError:
58
+ # If JSON parsing fails, attempt to fix common formatting issues.
59
+ fixed_output = fix_json_output(model_output)
60
+ try:
61
+ data = json.loads(fixed_output)
62
+ except json.JSONDecodeError:
63
+ # If it still fails, return the model output unmodified.
64
+ return model_output
65
 
 
66
  if amount:
67
+ try:
68
+ # Convert the cleaned string to a float
69
+ data["amount"] = float(amount.strip())
70
+ except ValueError:
71
+ # In case conversion fails, keep the original string.
72
+ data["amount"] = amount
73
 
 
74
  return json.dumps(data, ensure_ascii=False)
75
 
 
76
  def generate_command(input_command):
77
  # Extract the amount from the input
78
  amount = extract_amount(input_command)
 
81
  prompt = "extract: " + input_command
82
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
83
 
 
84
  output_ids = model.generate(
85
  input_ids,
86
+ max_length=128, # Increased max_length to allow for complete JSON output
87
+ num_beams=2, # Using beam search for better output quality
88
  early_stopping=True
89
  )
90
 
 
91
  model_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
92
 
93
+ # Merge the model's output with the extracted amount.
94
+ # The merge function only replaces the "amount" field, leaving all other keys intact.
95
  if amount:
96
  result = merge_json_with_amount(model_output, amount)
97
  else:
 
108
  )
109
 
110
  if __name__ == "__main__":
111
+ iface.launch()