RayBe commited on
Commit
a09d2a8
·
verified ·
1 Parent(s): e89ec39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -49
app.py CHANGED
@@ -1,71 +1,74 @@
1
- import re
2
  import torch
 
3
  import gradio as gr
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
6
- # Load the fine-tuned model and tokenizer from the local folder
7
  model_name = "./t5-finetuned-final"
8
  tokenizer = T5Tokenizer.from_pretrained(model_name)
9
  model = T5ForConditionalGeneration.from_pretrained(model_name)
10
 
11
- # Move model to GPU if available; otherwise, it will run on CPU
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model.to(device)
14
 
15
- # If using GPU, enable half precision and try torch.compile (PyTorch 2.0+)
16
  if torch.cuda.is_available():
17
- model.half() # Use half-precision for faster computation on GPU
18
  try:
19
- model = torch.compile(model)
20
- except Exception:
21
- pass # Continue if torch.compile is unavailable
22
 
23
- def fix_amount_in_output(input_command, output_str):
24
  """
25
- Extracts the first decimal number from the input and replaces the "amount" value
26
- in the output with that exact value.
27
  """
28
- # Look for a number with optional decimal separator in the input command.
29
- match = re.search(r'(\d+(?:[.,]\d+))', input_command)
30
- if match:
31
- # Normalize any commas to a period.
32
- correct_amount_str = match.group(1).replace(',', '.')
33
- else:
34
- return output_str
35
 
36
- # Replace the "amount" value in the output with the extracted amount.
37
- fixed_output = re.sub(
38
- r'("amount"\s*:\s*)(\d+(?:\.\d+)?)',
39
- r'\1' + correct_amount_str,
40
- output_str
41
- )
42
- return fixed_output
43
 
44
  def generate_command(input_command):
 
 
 
45
  prompt = "extract: " + input_command
46
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
47
-
48
- # Use greedy decoding (num_beams=1) on CPU for speed; otherwise, use beam search on GPU.
49
- if device.type == "cpu":
50
- output_ids = model.generate(
51
- input_ids,
52
- max_length=64,
53
- num_beams=1, # Greedy decoding for faster output on CPU
54
- early_stopping=True
55
- )
56
- else:
57
- output_ids = model.generate(
58
- input_ids,
59
- max_length=64,
60
- num_beams=3, # Beam search for potentially higher quality on GPU
61
- early_stopping=True
62
- )
63
-
64
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
- result_fixed = fix_amount_in_output(input_command, result)
66
- return result_fixed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- # Create a Gradio interface.
69
  iface = gr.Interface(
70
  fn=generate_command,
71
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
@@ -74,8 +77,6 @@ iface = gr.Interface(
74
  description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
75
  )
76
 
 
77
  if __name__ == "__main__":
78
- iface.launch()
79
-
80
-
81
-
 
1
+ import json
2
  import torch
3
+ import re
4
  import gradio as gr
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
7
+ # Load the fine-tuned model
8
  model_name = "./t5-finetuned-final"
9
  tokenizer = T5Tokenizer.from_pretrained(model_name)
10
  model = T5ForConditionalGeneration.from_pretrained(model_name)
11
 
12
+ # Move model to GPU if available
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
16
+ # Enable optimizations for GPU
17
  if torch.cuda.is_available():
18
+ model.half() # Use half-precision for faster computation
19
  try:
20
+ model = torch.compile(model) # PyTorch 2.0+ optimization
21
+ except:
22
+ pass # Ignore if torch.compile is not available
23
 
24
+ def sanitize_amount(output):
25
  """
26
+ Sanitizes the amount field to ensure it is correctly formatted.
 
27
  """
28
+ # Fix malformed amounts like "46307.0" -> "4630.07" or "4630327.0" -> "463032.07"
29
+ def fix_malformed_amount(match):
30
+ full_match = match.group(0)
31
+ integer_part = match.group(1)
32
+ decimal_part = match.group(2)
33
+ return f"{integer_part}.{decimal_part}" # Reconstruct the correct format
 
34
 
35
+ # Match numbers with misplaced decimal points
36
+ output = re.sub(r'(\d+)(\d{2})\.0', fix_malformed_amount, output)
37
+ return output
 
 
 
 
38
 
39
  def generate_command(input_command):
40
+ """
41
+ Generates the command and ensures the exact amount is displayed without changes.
42
+ """
43
  prompt = "extract: " + input_command
44
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
45
+
46
+ # Generate output from the model
47
+ output_ids = model.generate(
48
+ input_ids,
49
+ max_length=64, # Reduced for speed
50
+ num_beams=3, # Lowered from 5 to 3 for faster output
51
+ early_stopping=True
52
+ )
 
 
 
 
 
 
 
 
 
53
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
54
+
55
+ # Sanitize the output to fix malformed amounts
56
+ sanitized_result = sanitize_amount(result)
57
+
58
+ try:
59
+ # Attempt to parse the sanitized result as JSON
60
+ data = json.loads(sanitized_result)
61
+
62
+ # Convert numeric amounts to strings to preserve exact formatting
63
+ if isinstance(data.get("amount"), (int, float)):
64
+ data["amount"] = str(data["amount"])
65
+
66
+ return json.dumps(data, ensure_ascii=False) # Return as JSON string
67
+ except json.JSONDecodeError:
68
+ # If not valid JSON, return the raw sanitized output
69
+ return sanitized_result
70
 
71
+ # Create a Gradio interface
72
  iface = gr.Interface(
73
  fn=generate_command,
74
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
 
77
  description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
78
  )
79
 
80
+ # Launch the app
81
  if __name__ == "__main__":
82
+ iface.launch()