Update app.py
Browse files — added chunking service references
app.py
CHANGED
|
@@ -2,6 +2,9 @@ from fastapi import FastAPI, Request
|
|
| 2 |
from transformers import MarianMTModel, MarianTokenizer
|
| 3 |
import torch
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
app = FastAPI()
|
| 6 |
|
| 7 |
# Map target languages to Hugging Face model IDs
|
|
@@ -60,14 +63,29 @@ async def translate(request: Request):
|
|
| 60 |
if not model_id:
|
| 61 |
return {"error": f"No model found for target language '{target_lang}'"}
|
| 62 |
|
|
|
|
| 63 |
if model_id.startswith("facebook/"):
|
| 64 |
return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}
|
| 65 |
|
| 66 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
tokenizer, model = load_model(model_id)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
except Exception as e:
|
| 72 |
return {"error": f"Translation failed: {str(e)}"}
|
| 73 |
|
|
|
|
| 2 |
from transformers import MarianMTModel, MarianTokenizer
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# import chunking service
|
| 6 |
+
from chunking import get_max_word_length, chunk_text
|
| 7 |
+
|
| 8 |
app = FastAPI()
|
| 9 |
|
| 10 |
# Map target languages to Hugging Face model IDs
|
|
|
|
| 63 |
if not model_id:
|
| 64 |
return {"error": f"No model found for target language '{target_lang}'"}
|
| 65 |
|
| 66 |
+
# Facebook/mbart placeholder check
|
| 67 |
if model_id.startswith("facebook/"):
|
| 68 |
return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}
|
| 69 |
|
| 70 |
try:
|
| 71 |
+
# 1. figure out your safe word limit for this language
|
| 72 |
+
safe_limit = get_max_word_length([target_lang])
|
| 73 |
+
|
| 74 |
+
# 2. break the input up into chunks
|
| 75 |
+
chunks = chunk_text(text, safe_limit)
|
| 76 |
+
|
| 77 |
+
# 3. translate each chunk and collect results
|
| 78 |
tokenizer, model = load_model(model_id)
|
| 79 |
+
full_translation = []
|
| 80 |
+
for chunk in chunks:
|
| 81 |
+
inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(model.device)
|
| 82 |
+
outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
|
| 83 |
+
full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
| 84 |
+
|
| 85 |
+
# 4. re-join the translated pieces
|
| 86 |
+
joined = " ".join(full_translation)
|
| 87 |
+
return {"translation": joined}
|
| 88 |
+
|
| 89 |
except Exception as e:
|
| 90 |
return {"error": f"Translation failed: {str(e)}"}
|
| 91 |
|