RayBe committed on
Commit
e89ec39
·
verified ·
1 Parent(s): 7216ccc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -28
app.py CHANGED
@@ -3,39 +3,37 @@ import torch
3
  import gradio as gr
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
6
- # Load the fine-tuned model from the local folder
7
  model_name = "./t5-finetuned-final"
8
  tokenizer = T5Tokenizer.from_pretrained(model_name)
9
  model = T5ForConditionalGeneration.from_pretrained(model_name)
10
 
11
- # Move model to GPU if available
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  model.to(device)
14
 
15
- # Optimize for GPU: half precision and compilation (if supported)
16
  if torch.cuda.is_available():
17
- model.half() # Use half-precision for faster computation
18
  try:
19
- model = torch.compile(model) # Optimize with torch.compile() (PyTorch 2.0+)
20
  except Exception:
21
- pass # Continue if torch.compile() isn't available
22
 
23
  def fix_amount_in_output(input_command, output_str):
24
  """
25
- This function extracts the first decimal number found in the input_command
26
- and then replaces the "amount" field in the model output with that number.
27
  """
28
- # Extract the first number that has a decimal point (or comma) from the input.
29
  match = re.search(r'(\d+(?:[.,]\d+))', input_command)
30
  if match:
31
- # Normalize to use a period as the decimal separator.
32
  correct_amount_str = match.group(1).replace(',', '.')
33
  else:
34
- # If nothing is found, return the output unchanged.
35
  return output_str
36
 
37
- # Replace the amount value in the output.
38
- # This expects the output to contain a pattern like: "amount": some_number
39
  fixed_output = re.sub(
40
  r'("amount"\s*:\s*)(\d+(?:\.\d+)?)',
41
  r'\1' + correct_amount_str,
@@ -45,26 +43,29 @@ def fix_amount_in_output(input_command, output_str):
45
 
46
  def generate_command(input_command):
47
  prompt = "extract: " + input_command
48
-
49
- # Tokenize input and send to the correct device.
50
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
51
-
52
- # Generate output using optimized parameters.
53
- output_ids = model.generate(
54
- input_ids,
55
- max_length=64, # Reduced length for faster generation.
56
- num_beams=3, # Fewer beams for faster inference.
57
- early_stopping=True
58
- )
59
-
60
- # Decode the generated tokens.
61
- result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
62
 
63
- # Fix the "amount" field in the output using the input value.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  result_fixed = fix_amount_in_output(input_command, result)
65
  return result_fixed
66
 
67
- # Define a Gradio interface.
68
  iface = gr.Interface(
69
  fn=generate_command,
70
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
@@ -77,3 +78,4 @@ if __name__ == "__main__":
77
  iface.launch()
78
 
79
 
 
 
3
  import gradio as gr
4
  from transformers import T5Tokenizer, T5ForConditionalGeneration
5
 
6
# Location of the fine-tuned T5 checkpoint on local disk.
model_name = "./t5-finetuned-final"

# Load the tokenizer and the seq2seq model from the local folder.
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Prefer GPU when present; otherwise everything runs on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# GPU-only optimizations: half precision, then torch.compile when available.
if torch.cuda.is_available():
    model.half()  # fp16 halves memory and speeds up GPU inference
    try:
        model = torch.compile(model)
    except Exception:
        # torch.compile requires PyTorch 2.0+; silently keep the eager model.
        pass
22
 
23
  def fix_amount_in_output(input_command, output_str):
24
  """
25
+ Extracts the first decimal number from the input and replaces the "amount" value
26
+ in the output with that exact value.
27
  """
28
+ # Look for a number with optional decimal separator in the input command.
29
  match = re.search(r'(\d+(?:[.,]\d+))', input_command)
30
  if match:
31
+ # Normalize any commas to a period.
32
  correct_amount_str = match.group(1).replace(',', '.')
33
  else:
 
34
  return output_str
35
 
36
+ # Replace the "amount" value in the output with the extracted amount.
 
37
  fixed_output = re.sub(
38
  r'("amount"\s*:\s*)(\d+(?:\.\d+)?)',
39
  r'\1' + correct_amount_str,
 
43
 
44
def generate_command(input_command):
    """Run the fine-tuned T5 model on *input_command* and return the
    generated command string with its "amount" field corrected.

    The prompt is prefixed with "extract: " to match the fine-tuning task.
    Greedy decoding (num_beams=1) is used on CPU for speed; beam search
    (num_beams=3) is used on GPU for potentially higher quality.
    """
    prompt = "extract: " + input_command

    # Tokenize and move the input ids onto the same device as the model.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Inference only: disable autograd bookkeeping for speed and memory.
    with torch.no_grad():
        if device.type == "cpu":
            # Greedy decoding for faster output on CPU.
            # NOTE: early_stopping is NOT passed here — it only applies to
            # beam search and triggers a transformers warning with num_beams=1.
            output_ids = model.generate(
                input_ids,
                max_length=64,
                num_beams=1,
            )
        else:
            # Beam search for potentially higher quality on GPU.
            output_ids = model.generate(
                input_ids,
                max_length=64,
                num_beams=3,
                early_stopping=True,
            )

    # Decode, then copy the exact amount from the input into the output so
    # the model cannot hallucinate a different number.
    result = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    result_fixed = fix_amount_in_output(input_command, result)
    return result_fixed
67
 
68
+ # Create a Gradio interface.
69
  iface = gr.Interface(
70
  fn=generate_command,
71
  inputs=gr.Textbox(lines=2, placeholder="Enter a command..."),
 
78
  iface.launch()
79
 
80
 
81
+