RayBe commited on
Commit
74e1a78
·
verified ·
1 Parent(s): 5ddb235

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -40
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
- import torch
3
  import re
 
4
  import gradio as gr
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
@@ -9,74 +9,65 @@ model_name = "./t5-finetuned-final"
9
  tokenizer = T5Tokenizer.from_pretrained(model_name)
10
  model = T5ForConditionalGeneration.from_pretrained(model_name)
11
 
12
- # Move model to GPU if available
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
16
- # Enable optimizations for GPU
17
  if torch.cuda.is_available():
18
- model.half() # Use half-precision for faster computation
19
  try:
20
- model = torch.compile(model) # PyTorch 2.0+ optimization
21
- except Exception:
22
- pass # Ignore if torch.compile is not available
23
 
24
  def correct_amount_format(output):
25
- """
26
- This function attempts to fix the numeric formatting issues in the generated output:
27
- 1. It replaces a comma used as a decimal separator (i.e. followed by exactly two digits) with a period.
28
- 2. It converts the number to a float and rounds it to two decimal places.
29
-
30
- If the output is valid JSON, it will update the "amount" field accordingly.
31
- Otherwise, it falls back to a regex-based fix.
32
- """
33
  try:
34
- # Try to parse the output as JSON
35
  data = json.loads(output)
36
- if "amount" in data and isinstance(data["amount"], str):
37
- # Replace a comma that is likely a decimal separator (e.g., "10,50" -> "10.50")
38
- amount_str = re.sub(r'(\d+),(\d{2})\b', r'\1.\2', data["amount"])
39
- try:
40
- # Convert to float, round to two decimals, then reformat
41
- num = float(amount_str)
42
- rounded = round(num, 2)
43
- data["amount"] = "{:.2f}".format(rounded)
44
- except ValueError:
45
- # If conversion fails, leave the original value
46
- pass
 
 
 
 
 
 
47
  return json.dumps(data, ensure_ascii=False)
48
  except json.JSONDecodeError:
49
- # Fallback if output is not valid JSON:
50
- # Replace commas used as decimal separators (only if followed by exactly 2 digits)
51
- output = re.sub(r'(\d+),(\d{2})\b', r'\1.\2', output)
52
- # Fallback: truncate any extra digits (note: this does not round)
53
- output = re.sub(r'(\d+\.\d{2})\d+', r'\1', output)
54
  return output
55
 
56
  def generate_command(input_command):
57
  prompt = "extract: " + input_command
58
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
59
-
60
  output_ids = model.generate(
61
  input_ids,
62
- max_length=64, # Reduced for speed
63
- num_beams=3, # Lowered from 5 to 3 for faster output
64
  early_stopping=True
65
  )
66
-
67
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
68
- # Apply the updated post-processing to correct the amount formatting
69
  result = correct_amount_format(result)
70
  return result
71
 
72
- # Create a Gradio interface
73
  iface = gr.Interface(
74
  fn=generate_command,
75
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
76
  outputs=gr.Textbox(label="Extracted JSON Output"),
77
  title="T5 Fine-Tuned Command Extractor",
78
- description="Enter a command, and the fine-tuned T5 model will extract relevant details in JSON format.",
79
  )
80
 
81
  if __name__ == "__main__":
82
- iface.launch()
 
1
  import json
 
2
  import re
3
+ import torch
4
  import gradio as gr
5
  from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
 
9
  tokenizer = T5Tokenizer.from_pretrained(model_name)
10
  model = T5ForConditionalGeneration.from_pretrained(model_name)
11
 
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model.to(device)
14
 
 
15
  if torch.cuda.is_available():
16
+ model.half()
17
  try:
18
+ model = torch.compile(model)
19
+ except:
20
+ pass
21
 
22
  def correct_amount_format(output):
23
+ # Attempt to parse as JSON and correct amounts
 
 
 
 
 
 
 
24
  try:
 
25
  data = json.loads(output)
26
+
27
+ def correct_value(value):
28
+ if isinstance(value, str):
29
+ # Remove commas used as thousand separators
30
+ value = re.sub(r',(?=\d{3})', '', value)
31
+ # Replace the first comma with a period (decimal)
32
+ value = value.replace(',', '.', 1)
33
+ return value
34
+
35
+ # Correct each value in the JSON data
36
+ if isinstance(data, dict):
37
+ for key in data:
38
+ data[key] = correct_value(data[key])
39
+ elif isinstance(data, list):
40
+ for i in range(len(data)):
41
+ data[i] = correct_value(data[i])
42
+
43
  return json.dumps(data, ensure_ascii=False)
44
  except json.JSONDecodeError:
45
+ # Fallback for invalid JSON: basic corrections
46
+ output = re.sub(r'(\d+),(\d+)\b', r'\1.\2', output)
 
 
 
47
  return output
48
 
49
  def generate_command(input_command):
50
  prompt = "extract: " + input_command
51
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
52
+
53
  output_ids = model.generate(
54
  input_ids,
55
+ max_length=64,
56
+ num_beams=3,
57
  early_stopping=True
58
  )
59
+
60
  result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
61
  result = correct_amount_format(result)
62
  return result
63
 
 
64
  iface = gr.Interface(
65
  fn=generate_command,
66
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
67
  outputs=gr.Textbox(label="Extracted JSON Output"),
68
  title="T5 Fine-Tuned Command Extractor",
69
+ description="Extracts details in JSON format with exact amount preservation.",
70
  )
71
 
72
  if __name__ == "__main__":
73
+ iface.launch()