Spaces:
Build error
Build error
Commit ·
024b277
1
Parent(s): 778345e
Replace Helsinki-NLP translation pipeline with NLLB-200
Browse files
app.py
CHANGED
|
@@ -294,18 +294,80 @@ class EventDetector:
|
|
| 294 |
logger.info(f"Initializing models on device: {device}")
|
| 295 |
|
| 296 |
"""Initialize all models with GPU support"""
|
| 297 |
-
|
| 298 |
-
self.translator = pipeline(
|
| 299 |
-
"translation",
|
| 300 |
-
model="Helsinki-NLP/opus-mt-ru-en",
|
| 301 |
-
device=device
|
| 302 |
-
)
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
|
|
|
|
|
|
| 308 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
# Initialize sentiment models
|
| 311 |
self.finbert = pipeline(
|
|
|
|
| 294 |
logger.info(f"Initializing models on device: {device}")
|
| 295 |
|
| 296 |
"""Initialize all models with GPU support"""
|
| 297 |
+
# Replace the Helsinki-NLP opus-mt pipeline with NLLB-200: the opus-mt
# checkpoint conflicts with the pinned PyTorch version in this Space.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# NOTE: fixed stray ')' that was inside the original log message.
logger.info("replacing Helsinki-NLP due to conflict with PyTorch version")
self.translator_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
self.translator_model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    # safetensors avoids torch.load on pickled .bin weights — presumably
    # part of the PyTorch-version conflict being worked around; confirm.
    use_safetensors=True,
    device_map=device
)
|
| 306 |
+
|
| 307 |
+
# Create custom translation function for ru→en
|
| 308 |
+
def translate_ru_en(text_list):
    """Translate Russian text to English with NLLB-200.

    Mimics the Helsinki-NLP pipeline API: accepts a single string or a
    list of strings and returns a list of {"translation_text": str}
    dicts, one entry per input (empty input -> empty translation).
    """
    if not isinstance(text_list, list):
        text_list = [text_list]

    results = []
    for text in text_list:
        if not text:
            # Keep positional alignment with the inputs for empty items.
            results.append({"translation_text": ""})
            continue

        # NLLB requires the SOURCE language to be set on the tokenizer
        # before encoding so the input is prefixed with the right
        # language token; the default is eng_Latn, which is wrong for
        # Russian input.
        self.translator_tokenizer.src_lang = "rus_Cyrl"
        inputs = self.translator_tokenizer(text, return_tensors="pt").to(device)

        # Target language is forced via the generate() kwarg ONLY.
        # The original also stored forced_bos_token_id inside `inputs`,
        # so `**inputs` plus the explicit kwarg raised
        # "TypeError: got multiple values for 'forced_bos_token_id'".
        # convert_tokens_to_ids replaces lang_code_to_id, which was
        # removed from recent transformers tokenizers.
        with torch.no_grad():
            outputs = self.translator_model.generate(
                **inputs,
                forced_bos_token_id=self.translator_tokenizer.convert_tokens_to_ids("eng_Latn"),
                max_length=512,
            )

        # Decode and format like the Helsinki-NLP pipeline output.
        translation = self.translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        results.append({"translation_text": translation})

    return results
|
| 335 |
+
|
| 336 |
+
# Create custom translation function for en→ru
|
| 337 |
+
def translate_en_ru(text_list):
    """Translate English text to Russian with NLLB-200.

    Mimics the Helsinki-NLP pipeline API: accepts a single string or a
    list of strings and returns a list of {"translation_text": str}
    dicts, one entry per input (empty input -> empty translation).
    """
    if not isinstance(text_list, list):
        text_list = [text_list]

    results = []
    for text in text_list:
        if not text:
            # Keep positional alignment with the inputs for empty items.
            results.append({"translation_text": ""})
            continue

        # NLLB requires the SOURCE language to be set on the tokenizer
        # before encoding; set it explicitly (eng_Latn happens to be
        # the tokenizer default, but relying on that is fragile).
        self.translator_tokenizer.src_lang = "eng_Latn"
        inputs = self.translator_tokenizer(text, return_tensors="pt").to(device)

        # Target language is forced via the generate() kwarg ONLY.
        # The original also stored forced_bos_token_id inside `inputs`,
        # so `**inputs` plus the explicit kwarg raised
        # "TypeError: got multiple values for 'forced_bos_token_id'".
        # convert_tokens_to_ids replaces lang_code_to_id, which was
        # removed from recent transformers tokenizers.
        with torch.no_grad():
            outputs = self.translator_model.generate(
                **inputs,
                forced_bos_token_id=self.translator_tokenizer.convert_tokens_to_ids("rus_Cyrl"),
                max_length=512,
            )

        # Decode and format like the Helsinki-NLP pipeline output.
        translation = self.translator_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        results.append({"translation_text": translation})

    return results
|
| 364 |
+
|
| 365 |
+
# Replace pipeline with custom functions that mimic the original API
|
| 366 |
+
self.translator = translate_ru_en
|
| 367 |
+
self.rutranslator = translate_en_ru
|
| 368 |
+
|
| 369 |
+
logger.info("Translation models replaced successfully!")
|
| 370 |
+
|
| 371 |
|
| 372 |
# Initialize sentiment models
|
| 373 |
self.finbert = pipeline(
|