RayBe commited on
Commit
5ddb235
·
verified ·
1 Parent(s): a09d2a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -36
app.py CHANGED
@@ -18,55 +18,56 @@ if torch.cuda.is_available():
18
  model.half() # Use half-precision for faster computation
19
  try:
20
  model = torch.compile(model) # PyTorch 2.0+ optimization
21
- except:
22
  pass # Ignore if torch.compile is not available
23
 
24
- def sanitize_amount(output):
25
  """
26
- Sanitizes the amount field to ensure it is correctly formatted.
 
 
 
 
 
27
  """
28
- # Fix malformed amounts like "46307.0" -> "4630.07" or "4630327.0" -> "463032.07"
29
- def fix_malformed_amount(match):
30
- full_match = match.group(0)
31
- integer_part = match.group(1)
32
- decimal_part = match.group(2)
33
- return f"{integer_part}.{decimal_part}" # Reconstruct the correct format
34
-
35
- # Match numbers with misplaced decimal points
36
- output = re.sub(r'(\d+)(\d{2})\.0', fix_malformed_amount, output)
37
- return output
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def generate_command(input_command):
40
- """
41
- Generates the command and ensures the exact amount is displayed without changes.
42
- """
43
  prompt = "extract: " + input_command
44
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
45
-
46
- # Generate output from the model
47
  output_ids = model.generate(
48
  input_ids,
49
  max_length=64, # Reduced for speed
50
  num_beams=3, # Lowered from 5 to 3 for faster output
51
  early_stopping=True
52
  )
 
53
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
54
-
55
- # Sanitize the output to fix malformed amounts
56
- sanitized_result = sanitize_amount(result)
57
-
58
- try:
59
- # Attempt to parse the sanitized result as JSON
60
- data = json.loads(sanitized_result)
61
-
62
- # Convert numeric amounts to strings to preserve exact formatting
63
- if isinstance(data.get("amount"), (int, float)):
64
- data["amount"] = str(data["amount"])
65
-
66
- return json.dumps(data, ensure_ascii=False) # Return as JSON string
67
- except json.JSONDecodeError:
68
- # If not valid JSON, return the raw sanitized output
69
- return sanitized_result
70
 
71
  # Create a Gradio interface
72
  iface = gr.Interface(
@@ -77,6 +78,5 @@ iface = gr.Interface(
77
  description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
78
  )
79
 
80
- # Launch the app
81
  if __name__ == "__main__":
82
- iface.launch()
 
18
  model.half() # Use half-precision for faster computation
19
  try:
20
  model = torch.compile(model) # PyTorch 2.0+ optimization
21
+ except Exception:
22
  pass # Ignore if torch.compile is not available
23
 
24
+ def correct_amount_format(output):
25
  """
26
+ This function attempts to fix the numeric formatting issues in the generated output:
27
+ 1. It replaces a comma used as a decimal separator (i.e. followed by exactly two digits) with a period.
28
+ 2. It converts the number to a float and rounds it to two decimal places.
29
+
30
+ If the output is valid JSON, it will update the "amount" field accordingly.
31
+ Otherwise, it falls back to a regex-based fix.
32
  """
33
+ try:
34
+ # Try to parse the output as JSON
35
+ data = json.loads(output)
36
+ if "amount" in data and isinstance(data["amount"], str):
37
+ # Replace a comma that is likely a decimal separator (e.g., "10,50" -> "10.50")
38
+ amount_str = re.sub(r'(\d+),(\d{2})\b', r'\1.\2', data["amount"])
39
+ try:
40
+ # Convert to float, round to two decimals, then reformat
41
+ num = float(amount_str)
42
+ rounded = round(num, 2)
43
+ data["amount"] = "{:.2f}".format(rounded)
44
+ except ValueError:
45
+ # If conversion fails, leave the original value
46
+ pass
47
+ return json.dumps(data, ensure_ascii=False)
48
+ except json.JSONDecodeError:
49
+ # Fallback if output is not valid JSON:
50
+ # Replace commas used as decimal separators (only if followed by exactly 2 digits)
51
+ output = re.sub(r'(\d+),(\d{2})\b', r'\1.\2', output)
52
+ # Fallback: truncate any extra digits (note: this does not round)
53
+ output = re.sub(r'(\d+\.\d{2})\d+', r'\1', output)
54
+ return output
55
 
56
  def generate_command(input_command):
 
 
 
57
  prompt = "extract: " + input_command
58
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
59
+
 
60
  output_ids = model.generate(
61
  input_ids,
62
  max_length=64, # Reduced for speed
63
  num_beams=3, # Lowered from 5 to 3 for faster output
64
  early_stopping=True
65
  )
66
+
67
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
68
+ # Apply the updated post-processing to correct the amount formatting
69
+ result = correct_amount_format(result)
70
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Create a Gradio interface
73
  iface = gr.Interface(
 
78
  description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
79
  )
80
 
 
81
  if __name__ == "__main__":
82
+ iface.launch()