jing-ju committed
Commit 0fdedb6 · verified · 1 Parent(s): 70a3cb9

Update app.py

Files changed (1):
  1. app.py +89 -41
app.py CHANGED
@@ -87,26 +87,39 @@ def load_model():
         trust_remote_code=True
     )

-    # Create quantization config for fp8
+    # Create quantization config for fp8 - must use the actual class
     try:
-        from transformers.quantizers import CompressedTensorsQuantizationConfig
-        quantization_config = CompressedTensorsQuantizationConfig(
+        from compressed_tensors import CompressedTensorsConfig
+        quantization_config = CompressedTensorsConfig(
             quantization_method="fp8",
             ignore=[]
         )
+        print("Using CompressedTensorsConfig")
     except ImportError:
-        # Fallback to dict format
-        quantization_config = {
-            "quantization_method": "fp8",
-            "ignore": []
-        }
+        try:
+            from transformers.quantizers import CompressedTensorsQuantizationConfig
+            quantization_config = CompressedTensorsQuantizationConfig(
+                quantization_method="fp8",
+                ignore=[]
+            )
+            print("Using CompressedTensorsQuantizationConfig")
+        except ImportError:
+            # If both fail, load without custom quantization config
+            print("Loading model without custom quantization config")
+            quantization_config = None

     # Load model with quantization config
+    model_kwargs = {
+        "trust_remote_code": True,
+        "dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
+    }
+
+    if quantization_config is not None:
+        model_kwargs["quantization_config"] = quantization_config
+
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
-        trust_remote_code=True,
-        quantization_config=quantization_config,
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        **model_kwargs
     )

     return tokenizer, model
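The reworked loader builds model_kwargs up front and only attaches quantization_config when one of the optional imports succeeds, so from_pretrained never receives a None config. The dict also switches from torch_dtype to dtype; recent transformers releases accept dtype while older ones only recognize torch_dtype, so which key works depends on the pinned version. A minimal sketch of the same probe-and-degrade pattern (maybe_quantization_config is a hypothetical name, not in the commit):

    import torch
    from transformers import AutoModelForCausalLM

    def maybe_quantization_config():
        # Mirrors the commit's import; whether it resolves (and accepts these
        # kwargs) depends on the installed compressed-tensors version.
        try:
            from compressed_tensors import CompressedTensorsConfig
            return CompressedTensorsConfig(quantization_method="fp8", ignore=[])
        except (ImportError, TypeError):
            return None  # caller loads the model unquantized

    model_kwargs = {
        "trust_remote_code": True,
        "dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
    }
    cfg = maybe_quantization_config()
    if cfg is not None:
        model_kwargs["quantization_config"] = cfg  # pass only when a real config exists
    # model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_kwargs)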
@@ -141,14 +154,17 @@ def chunk_text_by_tokens(text, tokenizer, max_tokens):
                 chunks.append(current_chunk.strip())

             # If single sentence is too long, split it forcefully
-            if len(tokenizer.encode(sentence, add_special_tokens=False)) > max_tokens:
-                tokens = tokenizer.encode(sentence, add_special_tokens=False)
-                for i in range(0, len(tokens), max_tokens):
-                    chunk_tokens = tokens[i:i + max_tokens]
-                    chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
-                    chunks.append(chunk_text)
-                current_chunk = ""
-            else:
+            try:
+                sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
+                if len(sentence_tokens) > max_tokens:
+                    for i in range(0, len(sentence_tokens), max_tokens):
+                        chunk_tokens = sentence_tokens[i:i + max_tokens]
+                        chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
+                        chunks.append(chunk_text)
+                    current_chunk = ""
+                else:
+                    current_chunk = sentence
+            except:
                 current_chunk = sentence

     if current_chunk:
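The forced split walks the sentence's token ids in windows of max_tokens. The core slicing is plain Python; a runnable sketch with integers standing in for token ids:

    def split_by_stride(tokens, max_tokens):
        # Consecutive windows of at most max_tokens items each.
        return [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]

    print(split_by_stride(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

One caveat the commit inherits: decoding an arbitrary token window can land mid-way through a multi-token character sequence, so chunk boundaries may not round-trip exactly through decode/encode.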
@@ -171,12 +187,16 @@ def translate_text_chunk(text, target_lang, source_lang, tokenizer, model):
     prompt = f"Translate the following segment into {target_lang}, without additional explanation.\n\n{text}"

     # Apply chat template
-    messages = [{"role": "user", "content": prompt}]
-    input_text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    try:
+        messages = [{"role": "user", "content": prompt}]
+        input_text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+    except:
+        # Fallback if chat template fails
+        input_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"

     # Tokenize
     inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
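apply_chat_template renders the Jinja chat_template bundled with the tokenizer; the new except branch hardcodes ChatML-style markers instead. Whether Hunyuan-MT's actual template matches ChatML is an assumption the commit makes, not something this diff shows. Also note the bare except: swallows KeyboardInterrupt and SystemExit too; catching Exception is the narrower idiom, sketched here:

    messages = [{"role": "user", "content": prompt}]
    try:
        # Uses the template shipped with the tokenizer, when one is defined.
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    except Exception:
        # Assumed ChatML layout; verify against the model's real template.
        input_text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"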
@@ -186,7 +206,7 @@ def translate_text_chunk(text, target_lang, source_lang, tokenizer, model):
     outputs = model.generate(
         **inputs,
         **GEN_KW,
-        pad_token_id=tokenizer.eos_token_id
+        pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else tokenizer.pad_token_id
     )

     # Decode
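The new ternary falls back to pad_token_id when eos_token_id is unset, but truthiness testing has an edge case: a legitimate token id of 0 is falsy, so `if tokenizer.eos_token_id` would wrongly skip it. An explicit None check avoids that; a sketch using the same inputs and GEN_KW from this function:

    pad_id = (tokenizer.eos_token_id
              if tokenizer.eos_token_id is not None
              else tokenizer.pad_token_id)
    outputs = model.generate(**inputs, **GEN_KW, pad_token_id=pad_id)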
@@ -245,9 +265,21 @@ def translate_batch(text_lines, target_lang, source_lang, tokenizer, model):

 # Load model and tokenizer
 print("Initializing model...")
-tokenizer, model = load_model()
-device = model.device
-print(f"Model loaded on device: {device}")
+try:
+    tokenizer, model = load_model()
+    device = model.device
+    print(f"Model loaded successfully on device: {device}")
+except Exception as e:
+    print(f"Error loading model: {e}")
+    # Create dummy functions for interface
+    tokenizer = None
+    model = None
+
+    def dummy_translate(text, target_lang, source_lang):
+        return f"Model loading failed: {e}"
+
+    translate_single = dummy_translate
+    translate_batch = lambda text_lines, target_lang, source_lang, *args: dummy_translate(text_lines, target_lang, source_lang)

 # Create Gradio interface
 with gr.Blocks(title="Hunyuan-MT Multi-language Translation") as demo:
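One caveat in the fallback path: Python 3 unbinds the `except ... as e` variable when the handler block exits, so by the time dummy_translate actually runs, `e` no longer exists and the function raises NameError instead of returning the message. Binding the message to an ordinary variable first avoids this; a sketch using load_model from this file:

    try:
        tokenizer, model = load_model()
    except Exception as e:
        error_msg = f"Model loading failed: {e}"  # capture now; `e` is unbound after this block

        def dummy_translate(text, target_lang, source_lang):
            return error_msg  # closes over error_msg, which outlives the except block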
@@ -282,12 +314,20 @@ with gr.Blocks(title="Hunyuan-MT Multi-language Translation") as demo:
                 interactive=False
             )

-            translate_btn.click(
-                fn=lambda text, tgt, src: translate_single(text, tgt, src, tokenizer, model),
-                inputs=[input_text, target_lang, source_lang],
-                outputs=output_text,
-                api_name="translate_text"
-            )
+            if tokenizer and model:
+                translate_btn.click(
+                    fn=lambda text, tgt, src: translate_single(text, tgt, src, tokenizer, model),
+                    inputs=[input_text, target_lang, source_lang],
+                    outputs=output_text,
+                    api_name="translate_text"
+                )
+            else:
+                translate_btn.click(
+                    fn=lambda text, tgt, src: translate_single(text, tgt, src),
+                    inputs=[input_text, target_lang, source_lang],
+                    outputs=output_text,
+                    api_name="translate_text"
+                )

         with gr.TabItem("Batch Translation"):
             with gr.Row():
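The two click() registrations differ only in handler arity, matching the rebinding of translate_single in the failure path above. An alternative that avoids duplicating the wiring is a single handler that checks model state itself; a sketch under the same names, not what the commit does:

    def handle_translate(text, tgt, src):
        # One registration covers both the loaded and failed states.
        if tokenizer is None or model is None:
            return "Model is not loaded."
        return translate_single(text, tgt, src, tokenizer, model)

    translate_btn.click(
        fn=handle_translate,
        inputs=[input_text, target_lang, source_lang],
        outputs=output_text,
        api_name="translate_text",
    )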
@@ -315,12 +355,20 @@ with gr.Blocks(title="Hunyuan-MT Multi-language Translation") as demo:
                 interactive=False
             )

-            batch_translate_btn.click(
-                fn=lambda text, tgt, src: translate_batch(text, tgt, src, tokenizer, model),
-                inputs=[batch_input, batch_target_lang, batch_source_lang],
-                outputs=batch_output,
-                api_name="translate_batch"
-            )
+            if tokenizer and model:
+                batch_translate_btn.click(
+                    fn=lambda text, tgt, src: translate_batch(text, tgt, src, tokenizer, model),
+                    inputs=[batch_input, batch_target_lang, batch_source_lang],
+                    outputs=batch_output,
+                    api_name="translate_batch"
+                )
+            else:
+                batch_translate_btn.click(
+                    fn=lambda text, tgt, src: translate_batch(text, tgt, src),
+                    inputs=[batch_input, batch_target_lang, batch_source_lang],
+                    outputs=batch_output,
+                    api_name="translate_batch"
+                )

     gr.Markdown("### API Usage")
     gr.Markdown("""
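Since both buttons register named endpoints (api_name="translate_text" and "translate_batch"), the Space can also be driven programmatically with gradio_client. A usage sketch; the Space id is hypothetical and the argument order is assumed from the inputs lists above:

    from gradio_client import Client

    client = Client("jing-ju/hunyuan-mt-demo")  # hypothetical Space id
    result = client.predict(
        "Hello, world!",  # input_text
        "Chinese",        # target_lang
        "English",        # source_lang
        api_name="/translate_text",
    )
    print(result)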
 