RayBe committed on
Commit
7216ccc
·
verified ·
1 Parent(s): a4334ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -26
app.py CHANGED
@@ -1,10 +1,9 @@
1
- import json
2
- import torch
3
  import re
 
4
  import gradio as gr
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
7
- # Load the fine-tuned model
8
  model_name = "./t5-finetuned-final"
9
  tokenizer = T5Tokenizer.from_pretrained(model_name)
10
  model = T5ForConditionalGeneration.from_pretrained(model_name)
@@ -13,43 +12,59 @@ model = T5ForConditionalGeneration.from_pretrained(model_name)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
16
- # Enable optimizations for GPU
17
  if torch.cuda.is_available():
18
  model.half() # Use half-precision for faster computation
19
  try:
20
- model = torch.compile(model) # PyTorch 2.0+ optimization
21
- except:
22
- pass # Ignore if torch.compile is not available
23
 
24
- # Corrects formatting of amounts, handles commas and precision issues
25
- def correct_amount_format(output):
26
- # Fix the commas in decimal amounts
27
- output = re.sub(r'(\d+),(\d+)', r'\1.\2', output)
28
-
29
- # Ensure that numbers with more than 2 decimal places are rounded
30
- output = re.sub(r'(\d+\.\d{2})\d+', r'\1', output) # Keeps only 2 decimal places
31
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Function to generate command and parse amounts
34
  def generate_command(input_command):
35
  prompt = "extract: " + input_command
36
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
37
 
 
 
 
 
38
  output_ids = model.generate(
39
  input_ids,
40
- max_length=64, # Reduced for speed
41
- num_beams=3, # Lowered from 5 to 3 for faster output
42
  early_stopping=True
43
  )
44
-
 
45
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
46
 
47
- # Apply the post-processing fix to the amount
48
- result = correct_amount_format(result)
49
-
50
- return result
51
 
52
- # Create a Gradio interface
53
  iface = gr.Interface(
54
  fn=generate_command,
55
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
@@ -58,7 +73,7 @@ iface = gr.Interface(
58
  description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
59
  )
60
 
61
- # Launch the app
62
  if __name__ == "__main__":
63
  iface.launch()
64
 
 
 
 
 
1
  import re
2
+ import torch
3
  import gradio as gr
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
6
+ # Load the fine-tuned model from the local folder
7
  model_name = "./t5-finetuned-final"
8
  tokenizer = T5Tokenizer.from_pretrained(model_name)
9
  model = T5ForConditionalGeneration.from_pretrained(model_name)
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model.to(device)
14
 
15
+ # Optimize for GPU: half precision and compilation (if supported)
16
  if torch.cuda.is_available():
17
  model.half() # Use half-precision for faster computation
18
  try:
19
+ model = torch.compile(model) # Optimize with torch.compile() (PyTorch 2.0+)
20
+ except Exception:
21
+ pass # Continue if torch.compile() isn't available
22
 
23
def fix_amount_in_output(input_command, output_str):
    """
    Extract the first decimal number found in *input_command* and replace the
    "amount" field in the model output with that number.

    Parameters:
        input_command: raw user command text; the amount may use a period or a
            comma as the decimal separator (e.g. "12.50" or "12,50").
        output_str: model-generated string expected to contain a pattern like
            '"amount": <number>'.

    Returns:
        *output_str* with the "amount" value replaced by the amount taken from
        the input, or unchanged when the input contains no decimal number.
    """
    # Extract the first number that has a decimal point (or comma) from the input.
    match = re.search(r'\d+[.,]\d+', input_command)
    if not match:
        # Nothing to correct with — leave the model output untouched.
        return output_str

    # Normalize to use a period as the decimal separator.
    correct_amount_str = match.group(0).replace(',', '.')

    # Replace the amount value in the output.
    # Use a callable replacement: concatenating r'\1' with the digits (as the
    # original did) builds an ambiguous template such as '\112.50', which
    # re.sub parses as a reference to group 11 and raises
    # "invalid group reference".
    return re.sub(
        r'("amount"\s*:\s*)\d+(?:\.\d+)?',
        lambda m: m.group(1) + correct_amount_str,
        output_str,
    )
45
 
 
46
def generate_command(input_command):
    """Run the fine-tuned T5 model on *input_command* and return the
    post-processed extraction result as a string."""
    # Build the task prompt and move the token ids onto the model's device.
    encoded = tokenizer("extract: " + input_command, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)

    # Generate with speed-oriented settings: short max length, small beam count.
    output_ids = model.generate(
        input_ids,
        max_length=64,
        num_beams=3,
        early_stopping=True,
    )

    # Decode the best sequence, then repair the "amount" field using the
    # value taken directly from the user's input.
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return fix_amount_in_output(input_command, decoded)
 
66
 
67
+ # Define a Gradio interface.
68
  iface = gr.Interface(
69
  fn=generate_command,
70
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
 
73
  description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
74
  )
75
 
 
76
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
78
 
79
+