RayBe committed on
Commit
0db5562
·
verified ·
1 Parent(s): 56516af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -23
app.py CHANGED
@@ -19,6 +19,16 @@ if torch.cuda.is_available():
19
  except:
20
  pass
21
 
 
 
 
 
 
 
 
 
 
 
22
  def fix_json_output(output):
23
  """
24
  Fixes common JSON formatting issues in the model's output.
@@ -31,40 +41,29 @@ def fix_json_output(output):
31
  output = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*([,}])', r':"\1"\2', output)
32
  return output
33
 
34
- def correct_amount_format(output):
35
  """
36
- Corrects amount formatting in the JSON output.
37
  """
38
  try:
39
  # Fix JSON formatting issues
40
  output = fix_json_output(output)
41
  data = json.loads(output)
42
 
43
- def correct_value(value):
44
- if isinstance(value, str):
45
- # Fix amounts with multiple decimal points (e.g., 3140.98.0 → 3140.98)
46
- if re.match(r'^\d+\.\d+\.\d+$', value):
47
- value = value.split('.')[0] + '.' + value.split('.')[1]
48
- # Remove trailing .0 if it's not part of the original amount
49
- if re.match(r'^\d+\.0$', value):
50
- value = value.split('.')[0]
51
- return value
52
-
53
- # Correct each value in the JSON data
54
- if isinstance(data, dict):
55
- for key in data:
56
- data[key] = correct_value(data[key])
57
- elif isinstance(data, list):
58
- for i in range(len(data)):
59
- data[i] = correct_value(data[i])
60
 
61
  return json.dumps(data, ensure_ascii=False)
62
  except json.JSONDecodeError:
63
- # Fallback for invalid JSON: basic corrections
64
- output = re.sub(r'(\d+),(\d+)\b', r'\1.\2', output)
65
  return output
66
 
67
  def generate_command(input_command):
 
 
 
 
68
  prompt = "extract: " + input_command
69
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
70
 
@@ -76,7 +75,11 @@ def generate_command(input_command):
76
  )
77
 
78
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
79
- result = correct_amount_format(result)
 
 
 
 
80
  return result
81
 
82
  iface = gr.Interface(
@@ -84,7 +87,7 @@ iface = gr.Interface(
84
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
85
  outputs=gr.Textbox(label="Extracted JSON Output"),
86
  title="T5 Fine-Tuned Command Extractor",
87
- description="Extracts details in JSON format with exact amount preservation.",
88
  )
89
 
90
  if __name__ == "__main__":
 
19
  except:
20
  pass
21
 
22
+ def extract_amount(input_text):
23
+ """
24
+ Extracts the amount from the input text using a robust regex.
25
+ """
26
+ # Regex to match any valid number (integers or decimals)
27
+ amount_match = re.search(r'\d+(?:\.\d+)?', input_text)
28
+ if amount_match:
29
+ return amount_match.group(0)
30
+ return None
31
+
32
  def fix_json_output(output):
33
  """
34
  Fixes common JSON formatting issues in the model's output.
 
41
  output = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*([,}])', r':"\1"\2', output)
42
  return output
43
 
44
+ def replace_amount_in_json(output, amount):
45
  """
46
+ Replaces the amount in the model's JSON output with the extracted amount.
47
  """
48
  try:
49
  # Fix JSON formatting issues
50
  output = fix_json_output(output)
51
  data = json.loads(output)
52
 
53
+ # Replace the amount field if it exists
54
+ if "amount" in data:
55
+ data["amount"] = float(amount) if '.' in amount else int(amount)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  return json.dumps(data, ensure_ascii=False)
58
  except json.JSONDecodeError:
59
+ # Fallback for invalid JSON: return the original output
 
60
  return output
61
 
62
  def generate_command(input_command):
63
+ # Extract the amount from the input
64
+ amount = extract_amount(input_command)
65
+
66
+ # Generate the JSON output using the model
67
  prompt = "extract: " + input_command
68
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
69
 
 
75
  )
76
 
77
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
78
+
79
+ # Replace the model's amount with the extracted amount
80
+ if amount:
81
+ result = replace_amount_in_json(result, amount)
82
+
83
  return result
84
 
85
  iface = gr.Interface(
 
87
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
88
  outputs=gr.Textbox(label="Extracted JSON Output"),
89
  title="T5 Fine-Tuned Command Extractor",
90
+ description="Extracts details in JSON format and replaces the amount with the exact value from the input.",
91
  )
92
 
93
  if __name__ == "__main__":