Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
# code revision v9
|
| 2 |
-
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
@@ -147,7 +145,12 @@ class NLLBTranslator:
|
|
| 147 |
if source_lang == target_lang:
|
| 148 |
return text
|
| 149 |
|
| 150 |
-
logger.info(f"Translating from {source_lang} to {target_lang}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
# Check if simple or complex text
|
| 153 |
if '\n' not in text and len(text.split('.')) <= 2:
|
|
@@ -162,8 +165,44 @@ class NLLBTranslator:
|
|
| 162 |
|
| 163 |
except Exception as e:
|
| 164 |
logger.error(f"Translation error: {str(e)}")
|
|
|
|
|
|
|
| 165 |
return f"Error during translation: {str(e)}"
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
def perform_translation(self, input_sentences: list, source_code: str, target_code: str, paragraph_markers: list) -> str:
|
| 168 |
"""Perform the actual translation using NLLB model"""
|
| 169 |
batch_size = 2 # Conservative batch size for stability
|
|
@@ -174,6 +213,7 @@ class NLLBTranslator:
|
|
| 174 |
batch_size = 1
|
| 175 |
|
| 176 |
logger.info(f"Using batch size {batch_size} for average sentence length {avg_sentence_length:.1f} words")
|
|
|
|
| 177 |
|
| 178 |
all_translations = []
|
| 179 |
|
|
@@ -181,7 +221,8 @@ class NLLBTranslator:
|
|
| 181 |
batch_sentences = input_sentences[i:i + batch_size]
|
| 182 |
|
| 183 |
try:
|
| 184 |
-
# Tokenize input
|
|
|
|
| 185 |
inputs = self.tokenizer(
|
| 186 |
batch_sentences,
|
| 187 |
return_tensors="pt",
|
|
@@ -190,15 +231,24 @@ class NLLBTranslator:
|
|
| 190 |
max_length=512
|
| 191 |
).to(self.device)
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
# Generate translation
|
| 194 |
with torch.no_grad():
|
| 195 |
translated_tokens = self.model.generate(
|
| 196 |
**inputs,
|
| 197 |
-
forced_bos_token_id=
|
| 198 |
max_length=512,
|
| 199 |
num_beams=4,
|
| 200 |
early_stopping=True,
|
| 201 |
-
do_sample=False
|
|
|
|
|
|
|
| 202 |
)
|
| 203 |
|
| 204 |
# Decode translations
|
|
@@ -207,7 +257,17 @@ class NLLBTranslator:
|
|
| 207 |
skip_special_tokens=True
|
| 208 |
)
|
| 209 |
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
# Progress logging
|
| 213 |
if len(input_sentences) > 10:
|
|
@@ -217,9 +277,12 @@ class NLLBTranslator:
|
|
| 217 |
except Exception as e:
|
| 218 |
logger.error(f"Translation error in batch: {str(e)}")
|
| 219 |
|
| 220 |
-
# Fallback: process sentences individually
|
| 221 |
for single_sentence in batch_sentences:
|
| 222 |
try:
|
|
|
|
|
|
|
|
|
|
| 223 |
inputs = self.tokenizer(
|
| 224 |
single_sentence,
|
| 225 |
return_tensors="pt",
|
|
@@ -227,31 +290,47 @@ class NLLBTranslator:
|
|
| 227 |
max_length=512
|
| 228 |
).to(self.device)
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
with torch.no_grad():
|
| 231 |
-
translated_tokens = self.model.generate(
|
| 232 |
-
**inputs,
|
| 233 |
-
forced_bos_token_id=self.tokenizer.lang_code_to_id.get(target_code, self.tokenizer.eos_token_id),
|
| 234 |
-
max_length=512,
|
| 235 |
-
num_beams=4,
|
| 236 |
-
early_stopping=True
|
| 237 |
-
)
|
| 238 |
|
| 239 |
translation = self.tokenizer.decode(
|
| 240 |
translated_tokens[0],
|
| 241 |
skip_special_tokens=True
|
| 242 |
)
|
| 243 |
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
except Exception as single_e:
|
| 247 |
-
logger.error(f"Failed to translate sentence: {str(single_e)}")
|
| 248 |
-
all_translations.append(f"
|
| 249 |
|
| 250 |
# Reconstruct formatting
|
| 251 |
if paragraph_markers and len(all_translations) == len(paragraph_markers):
|
| 252 |
final_translation = self.reconstruct_formatting(all_translations, paragraph_markers)
|
| 253 |
else:
|
| 254 |
-
final_translation = ' '.join(all_translations) if all_translations else "Translation failed"
|
| 255 |
|
| 256 |
return final_translation
|
| 257 |
|
|
@@ -803,7 +882,7 @@ def translate_document(file, source_lang: str, target_lang: str, session_id: str
|
|
| 803 |
|
| 804 |
# Initialize translator
|
| 805 |
print("Initializing NLLB Translator...")
|
| 806 |
-
translator = NLLBTranslator(model_size="
|
| 807 |
|
| 808 |
# Create the Gradio app
|
| 809 |
with gr.Blocks(title="NLLB Universal Translator", theme=gr.themes.Soft()) as demo:
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
| 145 |
if source_lang == target_lang:
|
| 146 |
return text
|
| 147 |
|
| 148 |
+
logger.info(f"Translating from {source_lang} ({source_code}) to {target_lang} ({target_code})")
|
| 149 |
+
|
| 150 |
+
# For simple test, try a direct approach first
|
| 151 |
+
if text.strip() == "Hello, how are you today?":
|
| 152 |
+
logger.info("Using simple test translation")
|
| 153 |
+
return self.simple_translate(text, source_code, target_code)
|
| 154 |
|
| 155 |
# Check if simple or complex text
|
| 156 |
if '\n' not in text and len(text.split('.')) <= 2:
|
|
|
|
| 165 |
|
| 166 |
except Exception as e:
|
| 167 |
logger.error(f"Translation error: {str(e)}")
|
| 168 |
+
import traceback
|
| 169 |
+
traceback.print_exc()
|
| 170 |
return f"Error during translation: {str(e)}"
|
| 171 |
|
| 172 |
+
def simple_translate(self, text: str, source_code: str, target_code: str) -> str:
    """Translate one string with a single tokenizer/model round trip.

    This is the minimal path used for quick checks: no sentence splitting,
    no batching and no formatting reconstruction.

    Args:
        text: Text to translate. Input is truncated to 512 tokens.
        source_code: Language code assigned to ``self.tokenizer.src_lang``
            before tokenizing.
        target_code: Language code whose token is forced as the first
            generated token, so the model decodes into that language.

    Returns:
        The stripped translation; the fixed placeholder
        ``"Translation produced empty result"`` when decoding yields only
        whitespace; or ``"Simple translation failed: <error>"`` when any
        step raises. Errors are reported in-band and never propagate.
    """
    try:
        self.tokenizer.src_lang = source_code

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)

        generation_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "early_stopping": True,
            "do_sample": False,
        }

        # Previously target_code was ignored and nothing steered the output
        # language. Resolve the language token through the generic
        # token -> id lookup; an unknown code maps to the unk id (or None),
        # in which case we keep the old unforced behaviour instead of
        # forcing a meaningless first token.
        target_token_id = self.tokenizer.convert_tokens_to_ids(target_code)
        unknown_id = getattr(self.tokenizer, "unk_token_id", None)
        if target_token_id is None or target_token_id == unknown_id:
            logger.warning(
                f"Target language {target_code} not in tokenizer, "
                "generating without forced_bos_token_id"
            )
        else:
            generation_kwargs["forced_bos_token_id"] = target_token_id

        # Inference only: no gradients needed.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        translation = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.info(f"Simple translation result: {translation}")

        cleaned = translation.strip()
        return cleaned if cleaned else "Translation produced empty result"

    except Exception as e:
        # Best-effort by design: the caller expects a string, so report
        # the failure in the return value rather than raising.
        logger.error(f"Simple translation failed: {e}")
        return f"Simple translation failed: {e}"
|
| 205 |
+
|
| 206 |
def perform_translation(self, input_sentences: list, source_code: str, target_code: str, paragraph_markers: list) -> str:
|
| 207 |
"""Perform the actual translation using NLLB model"""
|
| 208 |
batch_size = 2 # Conservative batch size for stability
|
|
|
|
| 213 |
batch_size = 1
|
| 214 |
|
| 215 |
logger.info(f"Using batch size {batch_size} for average sentence length {avg_sentence_length:.1f} words")
|
| 216 |
+
logger.info(f"Translating from {source_code} to {target_code}")
|
| 217 |
|
| 218 |
all_translations = []
|
| 219 |
|
|
|
|
| 221 |
batch_sentences = input_sentences[i:i + batch_size]
|
| 222 |
|
| 223 |
try:
|
| 224 |
+
# Tokenize input with source language
|
| 225 |
+
self.tokenizer.src_lang = source_code
|
| 226 |
inputs = self.tokenizer(
|
| 227 |
batch_sentences,
|
| 228 |
return_tensors="pt",
|
|
|
|
| 231 |
max_length=512
|
| 232 |
).to(self.device)
|
| 233 |
|
| 234 |
+
# Get target language token ID
|
| 235 |
+
try:
|
| 236 |
+
target_token_id = self.tokenizer.lang_code_to_id[target_code]
|
| 237 |
+
except KeyError:
|
| 238 |
+
logger.warning(f"Language code {target_code} not found in tokenizer, using default")
|
| 239 |
+
target_token_id = self.tokenizer.pad_token_id
|
| 240 |
+
|
| 241 |
# Generate translation
|
| 242 |
with torch.no_grad():
|
| 243 |
translated_tokens = self.model.generate(
|
| 244 |
**inputs,
|
| 245 |
+
forced_bos_token_id=target_token_id,
|
| 246 |
max_length=512,
|
| 247 |
num_beams=4,
|
| 248 |
early_stopping=True,
|
| 249 |
+
do_sample=False,
|
| 250 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 251 |
+
eos_token_id=self.tokenizer.eos_token_id
|
| 252 |
)
|
| 253 |
|
| 254 |
# Decode translations
|
|
|
|
| 257 |
skip_special_tokens=True
|
| 258 |
)
|
| 259 |
|
| 260 |
+
# Normalize translations: trim whitespace and replace empty results with a placeholder
|
| 261 |
+
cleaned_translations = []
|
| 262 |
+
for trans in translations:
|
| 263 |
+
# Special tokens were already dropped by skip_special_tokens=True; only trim whitespace here
|
| 264 |
+
cleaned = trans.strip()
|
| 265 |
+
if cleaned:
|
| 266 |
+
cleaned_translations.append(cleaned)
|
| 267 |
+
else:
|
| 268 |
+
cleaned_translations.append("Translation produced empty result")
|
| 269 |
+
|
| 270 |
+
all_translations.extend(cleaned_translations)
|
| 271 |
|
| 272 |
# Progress logging
|
| 273 |
if len(input_sentences) > 10:
|
|
|
|
| 277 |
except Exception as e:
|
| 278 |
logger.error(f"Translation error in batch: {str(e)}")
|
| 279 |
|
| 280 |
+
# Fallback: process sentences individually with simpler approach
|
| 281 |
for single_sentence in batch_sentences:
|
| 282 |
try:
|
| 283 |
+
# Set source language
|
| 284 |
+
self.tokenizer.src_lang = source_code
|
| 285 |
+
|
| 286 |
inputs = self.tokenizer(
|
| 287 |
single_sentence,
|
| 288 |
return_tensors="pt",
|
|
|
|
| 290 |
max_length=512
|
| 291 |
).to(self.device)
|
| 292 |
|
| 293 |
+
# Try different approaches for target language
|
| 294 |
+
generation_kwargs = {
|
| 295 |
+
"max_length": 512,
|
| 296 |
+
"num_beams": 2,
|
| 297 |
+
"early_stopping": True,
|
| 298 |
+
"do_sample": False,
|
| 299 |
+
"pad_token_id": self.tokenizer.pad_token_id,
|
| 300 |
+
"eos_token_id": self.tokenizer.eos_token_id
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
# Try with target language token first
|
| 304 |
+
try:
|
| 305 |
+
target_token_id = self.tokenizer.lang_code_to_id[target_code]
|
| 306 |
+
generation_kwargs["forced_bos_token_id"] = target_token_id
|
| 307 |
+
except KeyError:
|
| 308 |
+
logger.warning(f"Target language {target_code} not in tokenizer, trying without forced_bos_token_id")
|
| 309 |
+
|
| 310 |
with torch.no_grad():
|
| 311 |
+
translated_tokens = self.model.generate(**inputs, **generation_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
translation = self.tokenizer.decode(
|
| 314 |
translated_tokens[0],
|
| 315 |
skip_special_tokens=True
|
| 316 |
)
|
| 317 |
|
| 318 |
+
# Clean the translation
|
| 319 |
+
cleaned_translation = translation.strip()
|
| 320 |
+
if cleaned_translation:
|
| 321 |
+
all_translations.append(cleaned_translation)
|
| 322 |
+
else:
|
| 323 |
+
all_translations.append("Empty translation result")
|
| 324 |
|
| 325 |
except Exception as single_e:
|
| 326 |
+
logger.error(f"Failed to translate sentence '{single_sentence}': {str(single_e)}")
|
| 327 |
+
all_translations.append(f"Translation failed: {str(single_e)}")
|
| 328 |
|
| 329 |
# Reconstruct formatting
|
| 330 |
if paragraph_markers and len(all_translations) == len(paragraph_markers):
|
| 331 |
final_translation = self.reconstruct_formatting(all_translations, paragraph_markers)
|
| 332 |
else:
|
| 333 |
+
final_translation = ' '.join(all_translations) if all_translations else "Translation failed - no output generated"
|
| 334 |
|
| 335 |
return final_translation
|
| 336 |
|
|
|
|
| 882 |
|
| 883 |
# Initialize translator
|
| 884 |
print("Initializing NLLB Translator...")
|
| 885 |
+
translator = NLLBTranslator(model_size="600M") # Use smaller model for stability
|
| 886 |
|
| 887 |
# Create the Gradio app
|
| 888 |
with gr.Blocks(title="NLLB Universal Translator", theme=gr.themes.Soft()) as demo:
|