saeedabdulmuizz commited on
Commit
48cb3b4
·
verified ·
1 Parent(s): edd0ea2

time optimization

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -87,9 +87,14 @@ def load_translation_models():
87
  # Load the LoRA adapter
88
  print("[*] Loading LoRA adapter...")
89
  model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
 
 
 
 
 
90
  model.eval()
91
 
92
- print(f"[+] Translation model loaded successfully on CPU.")
93
  _trans_cache["tokenizer"] = tokenizer
94
  _trans_cache["model"] = model
95
  _trans_cache["loaded"] = True
@@ -137,13 +142,15 @@ def _translate_impl(text):
137
  start_time = time.time()
138
  print("[DEBUG] Starting generation...")
139
 
140
- # Generation settings matching evaluate_model.py for Sarvam
 
 
141
  with torch.no_grad():
142
  generated = trans_model.generate(
143
  **inputs,
144
- max_new_tokens=512,
145
- do_sample=True,
146
- temperature=0.01,
147
  )
148
 
149
  elapsed = time.time() - start_time
@@ -182,13 +189,7 @@ def process(text, speaker_id):
182
  # Filter out any non-integer values (unknown characters not in vocabulary)
183
  # This happens when text contains characters not supported by the TTS model
184
  filtered_sequence = [s for s in sequence if isinstance(s, int)]
185
-
186
- if not filtered_sequence:
187
- raise ValueError("No valid characters found in input text for TTS model.")
188
-
189
- if len(filtered_sequence) != len(sequence):
190
- print(f"[WARN] Filtered out {len(sequence) - len(filtered_sequence)} unknown characters from TTS input")
191
-
192
  x = torch.tensor(intersperse(filtered_sequence, 0), dtype=torch.long, device=DEVICE)[None]
193
  x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)
194
 
 
87
  # Load the LoRA adapter
88
  print("[*] Loading LoRA adapter...")
89
  model = PeftModel.from_pretrained(base_model, TRANSLATION_ADAPTER)
90
+
91
+ # Merge LoRA weights into base model for faster inference
92
+ # This eliminates adapter overhead during generation
93
+ print("[*] Merging LoRA weights for faster inference...")
94
+ model = model.merge_and_unload()
95
  model.eval()
96
 
97
+ print(f"[+] Translation model loaded and merged successfully on CPU.")
98
  _trans_cache["tokenizer"] = tokenizer
99
  _trans_cache["model"] = model
100
  _trans_cache["loaded"] = True
 
142
  start_time = time.time()
143
  print("[DEBUG] Starting generation...")
144
 
145
+ # Generation settings optimized for CPU inference
146
+ # - Greedy decoding (do_sample=False) is faster than sampling
147
+ # - Same quality as temp=0.01 which was near-greedy anyway
148
  with torch.no_grad():
149
  generated = trans_model.generate(
150
  **inputs,
151
+ max_new_tokens=256,            # Reduced from 512 to cap generation time on CPU
152
+ do_sample=False, # Greedy decoding for speed
153
+ num_beams=1, # No beam search overhead
154
  )
155
 
156
  elapsed = time.time() - start_time
 
189
  # Filter out any non-integer values (unknown characters not in vocabulary)
190
  # This happens when text contains characters not supported by the TTS model
191
  filtered_sequence = [s for s in sequence if isinstance(s, int)]
192
+
 
 
 
 
 
 
193
  x = torch.tensor(intersperse(filtered_sequence, 0), dtype=torch.long, device=DEVICE)[None]
194
  x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=DEVICE)
195