saeedabdulmuizz committed on
Commit
dda3deb
·
verified ·
1 Parent(s): d8bac84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -11
app.py CHANGED
@@ -88,23 +88,55 @@ def translate(text):
88
 
89
  try:
90
  # Note: apply_chat_template returns input_ids tensor directly if tokenize=True and return_tensors="pt"
91
- input_ids = trans_tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(trans_model.device)
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
  print(f"Chat template error: {e}")
 
94
  return "Error in translation template."
95
 
96
- with torch.no_grad():
97
- # Use greedy decoding (do_sample=False) to avoid NaN/Inf issues with float16 sampling
98
- outputs = trans_model.generate(
99
- input_ids,
100
- max_new_tokens=256,
101
- do_sample=False, # Greedy decoding avoids multinomial NaN errors
102
- pad_token_id=trans_tokenizer.pad_token_id,
103
- eos_token_id=trans_tokenizer.eos_token_id,
104
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  # Slice reusing the input length
107
- decoded = trans_tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
 
 
 
 
 
108
  return decoded.strip()
109
 
110
 
 
88
 
89
  try:
90
  # Note: apply_chat_template returns input_ids tensor directly if tokenize=True and return_tensors="pt"
91
+ input_ids = trans_tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")
92
+
93
+ # Debug: Check devices
94
+ print(f"[DEBUG] Input device: {input_ids.device}")
95
+ print(f"[DEBUG] Model device: {trans_model.device}")
96
+ print(f"[DEBUG] Input shape: {input_ids.shape}")
97
+ print(f"[DEBUG] Input tokens: {input_ids.shape[1]}")
98
+
99
+ # Move input to model's device
100
+ input_ids = input_ids.to(trans_model.device)
101
+ print(f"[DEBUG] Input moved to: {input_ids.device}")
102
+
103
  except Exception as e:
104
  print(f"Chat template error: {e}")
105
+ traceback.print_exc()
106
  return "Error in translation template."
107
 
108
+ try:
109
+ import time
110
+ start_time = time.time()
111
+ print("[DEBUG] Starting generation...")
112
+
113
+ with torch.no_grad():
114
+ # Use greedy decoding (do_sample=False) to avoid NaN/Inf issues with float16 sampling
115
+ outputs = trans_model.generate(
116
+ input_ids,
117
+ max_new_tokens=128, # Reduced for faster generation
118
+ do_sample=False, # Greedy decoding avoids multinomial NaN errors
119
+ pad_token_id=trans_tokenizer.pad_token_id,
120
+ eos_token_id=trans_tokenizer.eos_token_id,
121
+ )
122
+
123
+ elapsed = time.time() - start_time
124
+ print(f"[DEBUG] Generation completed in {elapsed:.2f}s")
125
+ print(f"[DEBUG] Output shape: {outputs.shape}")
126
+ print(f"[DEBUG] New tokens generated: {outputs.shape[1] - input_ids.shape[1]}")
127
+
128
+ except Exception as e:
129
+ print(f"Generation error: {e}")
130
+ traceback.print_exc()
131
+ return "Error during translation generation."
132
 
133
  # Slice reusing the input length
134
+ new_tokens = outputs[0][input_ids.shape[1]:]
135
+ print(f"[DEBUG] New tokens to decode: {len(new_tokens)}")
136
+
137
+ decoded = trans_tokenizer.decode(new_tokens, skip_special_tokens=True)
138
+ print(f"[DEBUG] Decoded output: '{decoded}'")
139
+
140
  return decoded.strip()
141
 
142