RayBe committed on
Commit
f83c4be
·
verified ·
1 Parent(s): d8c303a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import torch
 
3
  import gradio as gr
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
@@ -20,22 +21,33 @@ if torch.cuda.is_available():
20
  except:
21
  pass # Ignore if torch.compile is not available
22
 
23
- # Define the function for inference
 
 
 
 
 
 
 
 
 
24
  def generate_command(input_command):
25
  prompt = "extract: " + input_command
26
-
27
- # Encode input and move tensors to the correct device
28
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
29
 
30
- # Generate with optimized settings
31
  output_ids = model.generate(
32
  input_ids,
33
  max_length=64, # Reduced for speed
34
- num_beams=3, # Lowered from 5 to 3 for faster output
35
  early_stopping=True
36
  )
37
 
38
- return tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
 
 
 
 
39
 
40
  # Create a Gradio interface
41
  iface = gr.Interface(
 
1
  import json
2
  import torch
3
+ import re
4
  import gradio as gr
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
 
21
  except:
22
  pass # Ignore if torch.compile is not available
23
 
24
+ # Corrects formatting of amounts, handles commas and precision issues
25
+ def correct_amount_format(output):
26
+ # Fix the commas in decimal amounts
27
+ output = re.sub(r'(\d+),(\d+)', r'\1.\2', output)
28
+
29
+ # Truncate numbers with more than 2 decimal places (extra digits are dropped, not rounded)
30
+ output = re.sub(r'(\d+\.\d{2})\d+', r'\1', output) # Keeps only 2 decimal places
31
+ return output
32
+
33
+ # Function to generate command and parse amounts
34
  def generate_command(input_command):
35
  prompt = "extract: " + input_command
 
 
36
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
37
 
 
38
  output_ids = model.generate(
39
  input_ids,
40
  max_length=64, # Reduced for speed
41
+ num_beams=3, # Lowered from 5 to 3 for faster output
42
  early_stopping=True
43
  )
44
 
45
+ result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
46
+
47
+ # Apply the post-processing fix to the amount
48
+ result = correct_amount_format(result)
49
+
50
+ return result
51
 
52
  # Create a Gradio interface
53
  iface = gr.Interface(