pentarosarium committed on
Commit
024b277
·
1 Parent(s): 778345e

replace Helsinki

Browse files
Files changed (1) hide show
  1. app.py +72 -10
app.py CHANGED
@@ -294,18 +294,80 @@ class EventDetector:
294
  logger.info(f"Initializing models on device: {device}")
295
 
296
  """Initialize all models with GPU support"""
297
- # Initialize translation model
298
- self.translator = pipeline(
299
- "translation",
300
- model="Helsinki-NLP/opus-mt-ru-en",
301
- device=device
302
- )
303
 
304
- self.rutranslator = pipeline(
305
- "translation",
306
- model="Helsinki-NLP/opus-mt-en-ru",
307
- device=device
 
 
308
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
 
310
  # Initialize sentiment models
311
  self.finbert = pipeline(
 
294
  logger.info(f"Initializing models on device: {device}")
295
 
296
  """Initialize all models with GPU support"""
297
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# The Helsinki-NLP/opus-mt checkpoints conflict with the installed PyTorch
# version, so both translation directions are served by a single
# multilingual NLLB-200 model instead (direction is chosen at generate()
# time via forced_bos_token_id).
# FIX: removed the stray ")" that the original left inside the log message.
logger.info("replacing Helsinki-NLP due to conflict with PyTorch version")
self.translator_tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M"
)
self.translator_model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    # safetensors weights avoid torch.load() on pickle .bin files
    use_safetensors=True,
    # NOTE(review): device_map is given a plain device here — assumes a
    # string/int device id acceptable to accelerate; confirm against the
    # `device` value computed above.
    device_map=device,
)
306
+
307
# Custom ru→en translation function mimicking the Helsinki-NLP pipeline API.
def translate_ru_en(text_list):
    """Translate Russian text to English with the shared NLLB model.

    Accepts a single string or a list of strings. Returns a list of
    ``{"translation_text": ...}`` dicts (the Helsinki-NLP pipeline output
    shape), with an empty string for empty/falsy inputs.
    """
    if not isinstance(text_list, list):
        text_list = [text_list]

    results = []
    for text in text_list:
        if not text:
            # Preserve positional correspondence for empty entries.
            results.append({"translation_text": ""})
            continue

        # NLLB tokenizers default to src_lang="eng_Latn"; the source here
        # is Russian, so set it explicitly before tokenizing.
        self.translator_tokenizer.src_lang = "rus_Cyrl"
        inputs = self.translator_tokenizer(text, return_tensors="pt").to(device)

        # BUG FIX: the original also stored forced_bos_token_id inside
        # `inputs`, so generate(**inputs, forced_bos_token_id=...) raised
        # "got multiple values for keyword argument 'forced_bos_token_id'".
        # Pass it exactly once, as a generate() kwarg.
        with torch.no_grad():
            outputs = self.translator_model.generate(
                **inputs,
                forced_bos_token_id=self.translator_tokenizer.lang_code_to_id["eng_Latn"],
                max_length=512,
            )

        # Decode and format like the Helsinki-NLP pipeline output.
        translation = self.translator_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )[0]
        results.append({"translation_text": translation})

    return results
335
+
336
# Custom en→ru translation function mimicking the Helsinki-NLP pipeline API.
def translate_en_ru(text_list):
    """Translate English text to Russian with the shared NLLB model.

    Accepts a single string or a list of strings. Returns a list of
    ``{"translation_text": ...}`` dicts (the Helsinki-NLP pipeline output
    shape), with an empty string for empty/falsy inputs.
    """
    if not isinstance(text_list, list):
        text_list = [text_list]

    results = []
    for text in text_list:
        if not text:
            # Preserve positional correspondence for empty entries.
            results.append({"translation_text": ""})
            continue

        # Source side is English; NLLB defaults to eng_Latn, but set it
        # explicitly since translate_ru_en may have changed it.
        self.translator_tokenizer.src_lang = "eng_Latn"
        inputs = self.translator_tokenizer(text, return_tensors="pt").to(device)

        # BUG FIX: the original also stored forced_bos_token_id inside
        # `inputs`, so generate(**inputs, forced_bos_token_id=...) raised
        # "got multiple values for keyword argument 'forced_bos_token_id'".
        # Pass it exactly once, as a generate() kwarg.
        with torch.no_grad():
            outputs = self.translator_model.generate(
                **inputs,
                forced_bos_token_id=self.translator_tokenizer.lang_code_to_id["rus_Cyrl"],
                max_length=512,
            )

        # Decode and format like the Helsinki-NLP pipeline output.
        translation = self.translator_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )[0]
        results.append({"translation_text": translation})

    return results
364
+
365
# Expose the NLLB-backed closures under the attribute names the rest of
# the class already uses, so callers of self.translator / self.rutranslator
# keep working without any change.
self.translator = translate_ru_en
self.rutranslator = translate_en_ru

logger.info("Translation models replaced successfully!")
370
+
371
 
372
  # Initialize sentiment models
373
  self.finbert = pipeline(