# tfg-api / app.py
# Last commit: 2c09eb9 "EasyOCR + NLLB optimizado 2" (asierfg794)
"""
app.py — FastAPI + EasyOCR + Gemini + NLLB + HiTZ zerbitzaria
OCR + postzuzenketa (Gemini 2.5 Flash) + itzulpena (NLLB-200 + HiTZ Marian).
"""
import io
import logging
import os
import re
import threading
import time
from contextlib import asynccontextmanager

import easyocr
import httpx
import numpy as np
import torch
from deskew import determine_skew
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from skimage.transform import rotate
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer,
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# EasyOCR language sets, one reader per writing system (see lifespan()).
# English is part of every set.
# NOTE: the devanagari and thai sets were removed to save RAM.
SCRIPTS = {
    "latin": (
        "en es fr de it pt nl pl cs sk hr "
        "ro hu lt lv et sv da no is mt sq tr vi"
    ).split(),
    "cyrillic": "en ru bg uk be rs_cyrillic mn".split(),
    "arabic": "en ar fa ur".split(),
    "chinese": "en ch_sim".split(),
    "japanese": "en ja".split(),
    "korean": "en ko".split(),
}

# Longest image side (pixels) kept before OCR; larger images are downscaled.
MAX_SIDE = 1280

# Gemini post-correction settings. An empty key disables the correction step.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
GEMINI_MODEL = "gemini-2.5-flash"
GEMINI_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"
# Spanish-language instruction prompt sent to Gemini by _gemini_correct().
# It asks the model to rejoin lines that the OCR split mid-sentence, to fix
# typical OCR character confusions, diacritics, punctuation and spacing, and
# to keep the original language(s) without translating or rewriting.
# The OCR output is substituted for the {text} placeholder via str.format(),
# so any other literal brace added to this template must be doubled ({{ }}).
CORRECTION_PROMPT = """Eres un corrector de texto extraído por OCR. Corrige los errores del OCR y devuelve únicamente el texto corregido, sin explicaciones ni comentarios.
TAREA PRIORITARIA — UNIR LÍNEAS QUE PERTENECEN A LA MISMA FRASE:
El OCR rompe frases en varias líneas porque la cámara captura imágenes con ancho de columna fijo. Tu trabajo es DESHACER esos cortes artificiales y juntar en una sola línea las líneas consecutivas que, por contexto, formen parte de la misma frase u oración.
REGLAS PARA UNIR LÍNEAS:
1. Si una línea termina en punto final ('.'), NUNCA la unas con la siguiente. Mantén el salto de línea entre ellas.
2. Si una línea termina en signo de cierre fuerte (! ? :), tampoco la unas con la siguiente.
3. En cualquier otro caso (línea sin punto final), evalúa si la línea siguiente continúa la idea gramaticalmente: si lo hace, únelas en una sola línea separadas por un único espacio.
4. Une también casos típicos como:
- Línea que termina con guion de partición de palabra ('-' al final): pega las dos mitades sin espacio y sin el guion.
- Línea que termina en coma, punto y coma, conjunción ('y', 'o', 'pero', 'and', 'or', 'eta', 'edo', etc.) o preposición: probablemente continúa, únelas.
- Línea que termina en mitad de un sintagma (artículo, adjetivo sin sustantivo, etc.): únelas.
5. Conserva los saltos de línea estructurales (líneas en blanco entre párrafos, títulos, encabezados, listas numeradas o con viñetas, datos tabulares, etiquetas tipo "Name:", "Date:").
CORREGIR TAMBIÉN:
- Caracteres confundidos: l/1/I, 0/O, rn/m, 5/S, 8/B, 6/G, cl/d, vv/w, etc.
- Tildes y diacríticos que faltan o están mal cuando el contexto lo deja claro.
- Palabras rotas o mal leídas cuando el significado correcto es evidente por el contexto.
- Puntuación claramente errónea.
- Espacios de más o de menos entre palabras.
NO TOCAR:
- No traduzcas. Mantén el idioma original (puede haber varios idiomas en el mismo texto).
- No reescribas ni mejores el estilo. Solo corrige errores y une líneas.
- No cambies nombres propios, marcas, códigos, fechas, precios, URLs ni números, salvo error de OCR evidente.
- No añadas ni elimines información.
- Si una palabra es ilegible y no puedes deducirla con certeza, déjala como está.
OCR text:
---
{text}
---
Corrected text:"""
# EasyOCR readers keyed by SCRIPTS name; filled in by lifespan().
readers: dict = {}

NLLB_MODEL_NAME = "facebook/nllb-200-distilled-600M"

# HiTZ Marian checkpoints keyed by (source ISO code, target ISO code).
# Every repository follows the "HiTZ/mt-hitz-<src>-<tgt>" naming pattern.
HITZ_PAIRS = {
    (src, tgt): f"HiTZ/mt-hitz-{src}-{tgt}"
    for src, tgt in (
        ("en", "eu"),
        ("eu", "en"),
        ("es", "eu"),
        ("eu", "es"),
        ("gl", "eu"),
        ("ca", "eu"),
    )
}
# ISO 639 code (as accepted by /translate) -> NLLB-200 language token.
# "ar" and "no" are aliases for the codes used by the EasyOCR language sets in
# SCRIPTS (Arabic -> Modern Standard Arabic, Norwegian -> Bokmål). Without them
# the OCR language codes for those scripts could not be passed on for
# translation; they would be rejected as unsupported.
ISO_TO_NLLB = {
    "ace": "ace_Latn", "acm": "acm_Arab", "acq": "acq_Arab", "aeb": "aeb_Arab",
    "af": "afr_Latn", "ajp": "ajp_Arab", "ak": "aka_Latn", "am": "amh_Ethi",
    "apc": "apc_Arab", "ar": "arb_Arab", "arb": "arb_Arab", "ars": "ars_Arab",
    "ary": "ary_Arab",
    "arz": "arz_Arab", "as": "asm_Beng", "ast": "ast_Latn", "awa": "awa_Deva",
    "ay": "ayr_Latn", "azb": "azb_Arab", "az": "azj_Latn", "ba": "bak_Cyrl",
    "bm": "bam_Latn", "ban": "ban_Latn", "be": "bel_Cyrl", "bem": "bem_Latn",
    "bn": "ben_Beng", "bho": "bho_Deva", "bjn": "bjn_Latn", "bo": "bod_Tibt",
    "bs": "bos_Latn", "bug": "bug_Latn", "bg": "bul_Cyrl", "ca": "cat_Latn",
    "ceb": "ceb_Latn", "cs": "ces_Latn", "cjk": "cjk_Latn", "ckb": "ckb_Arab",
    "crh": "crh_Latn", "cy": "cym_Latn", "da": "dan_Latn", "de": "deu_Latn",
    "dik": "dik_Latn", "dyu": "dyu_Latn", "dz": "dzo_Tibt", "el": "ell_Grek",
    "en": "eng_Latn", "eo": "epo_Latn", "et": "est_Latn", "eu": "eus_Latn",
    "ee": "ewe_Latn", "fo": "fao_Latn", "fj": "fij_Latn", "fi": "fin_Latn",
    "fon": "fon_Latn", "fr": "fra_Latn", "fur": "fur_Latn", "fuv": "fuv_Latn",
    "gd": "gla_Latn", "ga": "gle_Latn", "gl": "glg_Latn", "gn": "grn_Latn",
    "gu": "guj_Gujr", "ht": "hat_Latn", "ha": "hau_Latn", "he": "heb_Hebr",
    "hi": "hin_Deva", "hne": "hne_Deva", "hr": "hrv_Latn", "hu": "hun_Latn",
    "hy": "hye_Armn", "ig": "ibo_Latn", "ilo": "ilo_Latn", "id": "ind_Latn",
    "is": "isl_Latn", "it": "ita_Latn", "jv": "jav_Latn", "ja": "jpn_Jpan",
    "kab": "kab_Latn", "kac": "kac_Latn", "kam": "kam_Latn", "kn": "kan_Knda",
    "ks": "kas_Arab", "ka": "kat_Geor", "knc": "knc_Latn", "kk": "kaz_Cyrl",
    "kbp": "kbp_Latn", "kea": "kea_Latn", "km": "khm_Khmr", "ki": "kik_Latn",
    "rw": "kin_Latn", "ky": "kir_Cyrl", "kmb": "kmb_Latn", "kmr": "kmr_Latn",
    "kg": "kon_Latn", "ko": "kor_Hang", "lo": "lao_Laoo", "lij": "lij_Latn",
    "li": "lim_Latn", "ln": "lin_Latn", "lt": "lit_Latn", "lmo": "lmo_Latn",
    "ltg": "ltg_Latn", "lb": "ltz_Latn", "lua": "lua_Latn", "lg": "lug_Latn",
    "luo": "luo_Latn", "lus": "lus_Latn", "lv": "lvs_Latn", "mag": "mag_Deva",
    "mai": "mai_Deva", "ml": "mal_Mlym", "mr": "mar_Deva", "min": "min_Latn",
    "mk": "mkd_Cyrl", "mg": "plt_Latn", "mt": "mlt_Latn", "mni": "mni_Beng",
    "mn": "khk_Cyrl", "mos": "mos_Latn", "mi": "mri_Latn", "my": "mya_Mymr",
    "nl": "nld_Latn", "nn": "nno_Latn", "nb": "nob_Latn", "no": "nob_Latn",
    "ne": "npi_Deva",
    "nso": "nso_Latn", "nus": "nus_Latn", "ny": "nya_Latn", "oc": "oci_Latn",
    "om": "gaz_Latn", "or": "ory_Orya", "pag": "pag_Latn", "pa": "pan_Guru",
    "pap": "pap_Latn", "fa": "pes_Arab", "pl": "pol_Latn", "pt": "por_Latn",
    "prs": "prs_Arab", "ps": "pbt_Arab", "qu": "quy_Latn", "ro": "ron_Latn",
    "rn": "run_Latn", "ru": "rus_Cyrl", "sg": "sag_Latn", "sa": "san_Deva",
    "sat": "sat_Olck", "scn": "scn_Latn", "shn": "shn_Mymr", "si": "sin_Sinh",
    "sk": "slk_Latn", "sl": "slv_Latn", "sm": "smo_Latn", "sn": "sna_Latn",
    "sd": "snd_Arab", "so": "som_Latn", "st": "sot_Latn", "es": "spa_Latn",
    "sq": "als_Latn", "sc": "srd_Latn", "sr": "srp_Cyrl", "ss": "ssw_Latn",
    "su": "sun_Latn", "sv": "swe_Latn", "sw": "swh_Latn", "szl": "szl_Latn",
    "ta": "tam_Taml", "tt": "tat_Cyrl", "te": "tel_Telu", "tg": "tgk_Cyrl",
    "tl": "tgl_Latn", "th": "tha_Thai", "ti": "tir_Ethi", "taq": "taq_Latn",
    "tpi": "tpi_Latn", "tn": "tsn_Latn", "ts": "tso_Latn", "tk": "tuk_Latn",
    "tum": "tum_Latn", "tr": "tur_Latn", "tw": "twi_Latn", "tzm": "tzm_Tfng",
    "ug": "uig_Arab", "uk": "ukr_Cyrl", "umb": "umb_Latn", "ur": "urd_Arab",
    "uz": "uzn_Latn", "vec": "vec_Latn", "vi": "vie_Latn", "war": "war_Latn",
    "wo": "wol_Latn", "xh": "xho_Latn", "yi": "ydd_Hebr", "yo": "yor_Latn",
    "yue": "yue_Hant", "zh": "zho_Hans", "zht": "zho_Hant", "ms": "zsm_Latn",
    "zu": "zul_Latn",
}
# NLLB model and tokenizer, loaded by lifespan() at startup; None until then.
nllb_model = nllb_tokenizer = None
# HiTZ Marian bundles keyed by (src, tgt); _hitz_translate() reads
# bundle["tokenizer"] and bundle["model"]. Nothing fills it while HiTZ
# loading is disabled in lifespan().
hitz_models: dict = {}
def _resize(img: Image.Image) -> Image.Image:
    """Downscale *img* so that its longest side is at most MAX_SIDE pixels.

    Images already within the limit are returned unchanged (same object).
    The aspect ratio is preserved (up to integer truncation). Each new side is
    clamped to at least 1 px: for very elongated images the short side would
    otherwise truncate to 0, and PIL's resize rejects zero-sized targets.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= MAX_SIDE:
        return img
    scale = MAX_SIDE / longest
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    logger.info("[RESIZE] %dx%d -> %dx%d", w, h, new_w, new_h)
    return img.resize((new_w, new_h), Image.LANCZOS)
def _deskew(img_array: np.ndarray) -> np.ndarray:
    """Rotate a slightly tilted RGB image array back to horizontal.

    The skew angle is estimated on a uint8 grayscale copy (channel mean).
    If no angle is found, or its magnitude is below 0.5° or above 15°
    (presumably treated as a misdetection), the input is returned untouched.
    Otherwise the image is rotated with resize=True so no corner is cropped.
    The code assumes skimage's rotate() yields floats in [0, 1]. That is why
    cval=1.0 fills the uncovered area with white and the result is scaled by
    255 before the cast back to uint8.
    """
    grayscale = np.mean(img_array, axis=2).astype(np.uint8)
    angle = determine_skew(grayscale)
    magnitude = None if angle is None else abs(angle)
    if magnitude is None or magnitude < 0.5 or magnitude > 15:
        return img_array

    logger.info("[DESKEW] %.2f gradu zuzendu", angle)
    straightened = rotate(img_array, angle, resize=True, cval=1.0)
    return (straightened * 255).astype(np.uint8)
def _group_into_lines(ocr_results: list) -> str:
    """Rebuild reading-order text from EasyOCR ``readtext(detail=1)`` results.

    Each result is unpacked as ``(bbox, text, score)``. bbox is a sequence of
    (x, y) points, presumably the four box corners EasyOCR returns. The score
    is ignored.

    Boxes are ordered by vertical centre. A box joins the current line while
    its centre is within 0.6 × the line's mean box height of the previous
    box's centre. Words inside a line are ordered left to right and joined
    with single spaces. Lines are joined with newlines, and an extra empty
    line marks a paragraph break whenever the distance between consecutive
    line centres exceeds 1.8 × the upper line's mean height.
    """
    if not ocr_results:
        return ""

    def _box(corners, label):
        ys = [pt[1] for pt in corners]
        xs = [pt[0] for pt in corners]
        return {
            "text": label,
            "y": (min(ys) + max(ys)) / 2,
            "x": min(xs),
            "h": max(ys) - min(ys),
        }

    boxes = sorted(
        (_box(bbox, text) for bbox, text, _ in ocr_results),
        key=lambda b: b["y"],
    )

    # Chain vertically-adjacent boxes into rows.
    rows = []
    for box in boxes:
        if rows:
            row = rows[-1]
            mean_h = sum(b["h"] for b in row) / len(row)
            if abs(box["y"] - row[-1]["y"]) <= mean_h * 0.6:
                row.append(box)
                continue
        rows.append([box])

    for row in rows:
        row.sort(key=lambda b: b["x"])
    texts = [" ".join(b["text"] for b in row) for row in rows]
    if len(rows) <= 1:
        return texts[0]

    centres = [sum(b["y"] for b in row) / len(row) for row in rows]
    heights = [sum(b["h"] for b in row) / len(row) for row in rows]

    out = [texts[0]]
    for prev_c, prev_h, cur_c, text in zip(centres, heights, centres[1:], texts[1:]):
        if cur_c - prev_c > prev_h * 1.8:
            out.append("")  # large vertical gap -> paragraph break
        out.append(text)
    return "\n".join(out)
def _gemini_extract_text(data):
    """Extract the corrected text from a Gemini ``generateContent`` response.

    Concatenates the text of every non-"thought" part of the first candidate.
    Returns None, meaning "keep the uncorrected OCR text", when the response
    does not have the expected shape, when generation stopped on the output
    token limit (the answer would be truncated and silently drop part of the
    document), or when the answer is blank.
    """
    try:
        candidate = data["candidates"][0]
        parts = candidate["content"]["parts"]
        finish_reason = candidate.get("finishReason")
        answer = "".join(
            part.get("text", "") for part in parts if not part.get("thought")
        ).strip()
    except (KeyError, IndexError, TypeError, AttributeError) as e:
        logger.warning("[GEMINI] Erantzun-formatu okerra: %s", e)
        return None
    if finish_reason == "MAX_TOKENS":
        logger.warning("[GEMINI] Erantzuna moztuta (MAX_TOKENS); OCR testua mantentzen")
        return None
    if not answer:
        logger.warning("[GEMINI] Erantzun hutsa; OCR testua mantentzen")
        return None
    return answer


async def _gemini_correct(text: str) -> str:
    """Post-correct OCR *text* with Gemini (CORRECTION_PROMPT).

    This is best effort: the input is returned unchanged when it is blank,
    when no API key is configured, on any network/HTTP/JSON error, and when
    the response is malformed, truncated or empty.

    The API key goes in the ``x-goog-api-key`` header, not in a ``?key=``
    query parameter. The request URL appears in httpx's request logging and
    in ``HTTPStatusError`` messages (which are logged below), so a key in the
    URL would leak into the service logs.
    """
    if not text.strip():
        return text
    if not GEMINI_API_KEY:
        logger.warning("[GEMINI] GEMINI_API_KEY ez dago konfiguratuta")
        return text
    payload = {
        "contents": [{"parts": [{"text": CORRECTION_PROMPT.format(text=text)}]}],
        "generationConfig": {"temperature": 0.1, "maxOutputTokens": 4096},
    }
    try:
        async with httpx.AsyncClient(timeout=20.0) as client:
            response = await client.post(
                GEMINI_URL,
                json=payload,
                headers={"x-goog-api-key": GEMINI_API_KEY},
            )
            response.raise_for_status()
            data = response.json()
    except Exception as e:  # best effort: any failure keeps the OCR text
        logger.warning("[GEMINI] Errorea: %s", e)
        return text

    corrected = _gemini_extract_text(data)
    if corrected is None:
        return text
    logger.info("[GEMINI] Zuzenduta. %d -> %d kar.", len(text), len(corrected))
    return corrected
# Sentence boundary: whitespace preceded by ".", "!" or "?", optionally with
# one closing quote/bracket (" » ’ ' ) ]) in between. The closer is matched
# inside a lookbehind so it stays attached to the sentence it closes instead
# of being consumed (and lost) as part of the separator. Python lookbehinds
# must be fixed-width, hence one branch per case.
_SENT_END_RE = re.compile(r'(?<=[.!?])\s+|(?<=[.!?]["\u00bb\u2019\')\]])\s+')
# Maximum characters per translation unit; longer sentences are word-wrapped
# by _flatten_to_sentences().
_MAX_CHARS_PER_CHUNK = 1200
def _flatten_to_sentences(text: str):
    """Split *text* into translation units while keeping paragraph breaks.

    Non-blank lines are stripped and merged, space-separated, into paragraphs.
    Every blank line becomes a None marker. Each paragraph is then split with
    _SENT_END_RE and empty pieces are dropped. A sentence longer than
    _MAX_CHARS_PER_CHUNK characters is greedily word-wrapped into chunks of at
    most that size; a single word longer than the limit becomes its own
    oversized chunk.

    Returns:
        list: sentence strings interleaved with None markers. _rebuild() turns
        it back into text given one translation per string.
    """
    def _wrap(long_sentence):
        # Greedy wrap on single spaces, respecting _MAX_CHARS_PER_CHUNK.
        chunks = []
        window = ""
        for word in long_sentence.split(" "):
            if len(window) + len(word) + 1 > _MAX_CHARS_PER_CHUNK:
                if window:
                    chunks.append(window.strip())
                window = word
            elif window:
                window = (window + " " + word).strip()
            else:
                window = word
        if window.strip():
            chunks.append(window.strip())
        return chunks

    # Pass 1: lines -> paragraphs, None for every blank line.
    paragraphs = []
    pending = []
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if line:
            pending.append(line)
            continue
        if pending:
            paragraphs.append(" ".join(pending))
            pending = []
        paragraphs.append(None)
    if pending:
        paragraphs.append(" ".join(pending))

    # Pass 2: paragraphs -> sentences (wrapped when too long).
    units = []
    for paragraph in paragraphs:
        if paragraph is None:
            units.append(None)
            continue
        for piece in _SENT_END_RE.split(paragraph):
            piece = piece.strip()
            if not piece:
                continue
            if len(piece) > _MAX_CHARS_PER_CHUNK:
                units.extend(_wrap(piece))
            else:
                units.append(piece)
    return units
def _rebuild(blocks_in, translations):
    """Reassemble translated text from _flatten_to_sentences() output.

    *blocks_in* holds sentences with a None for every blank-line marker.
    *translations* holds one translated string per non-None entry, in order.
    Consecutive translated sentences are joined with spaces into one
    paragraph line, each None becomes an empty line, and all lines are
    joined with newlines.
    """
    lines = []
    paragraph = []
    cursor = 0
    for block in blocks_in:
        if block is not None:
            paragraph.append(translations[cursor])
            cursor += 1
            continue
        if paragraph:
            lines.append(" ".join(paragraph))
            paragraph = []
        lines.append("")
    if paragraph:
        lines.append(" ".join(paragraph))
    return "\n".join(lines)
def _adaptive_max_tokens(sentence: str) -> int:
    """Estimate a max-new-tokens budget for translating *sentence*.

    Uses len(sentence) // 4 (at least 8) as a rough source-token count,
    allows 1.8 times that for the output, and clamps the result to [32, 512].
    """
    estimated = len(sentence) // 4
    if estimated < 8:
        estimated = 8
    budget = int(estimated * 1.8)
    return max(32, min(512, budget))
# The NLLB tokenizer takes its source language from the mutable `src_lang`
# attribute of the single shared tokenizer. /translate is a sync endpoint that
# FastAPI runs in a worker thread pool, so without this lock two concurrent
# requests could overwrite each other's src_lang between the assignment and
# the encode call, and encode with the wrong language.
_NLLB_TOKENIZER_LOCK = threading.Lock()


def _nllb_translate(text: str, src_nllb: str, tgt_nllb: str) -> str:
    """Translate *text* with NLLB-200, sentence by sentence, in batches of 8.

    Args:
        text: Source text. Blank lines are kept as paragraph breaks.
        src_nllb: NLLB source language token (e.g. "eng_Latn").
        tgt_nllb: NLLB target language token (e.g. "eus_Latn").

    Returns:
        The translation rebuilt with the source paragraph layout (see
        _flatten_to_sentences / _rebuild). Whitespace-only input is returned
        unchanged.

    Raises:
        RuntimeError: if the NLLB model/tokenizer has not been loaded.
        ValueError: if *tgt_nllb* is not a token known to the tokenizer
            (it would otherwise silently map to the unknown token).
    """
    if not text.strip():
        return text
    if nllb_model is None or nllb_tokenizer is None:
        raise RuntimeError("NLLB eredua ez dago kargatuta")
    blocks = _flatten_to_sentences(text)
    to_translate = [b for b in blocks if b is not None]
    if not to_translate:
        return text

    forced_bos = nllb_tokenizer.convert_tokens_to_ids(tgt_nllb)
    if forced_bos is None or forced_bos == nllb_tokenizer.unk_token_id:
        raise ValueError(f"NLLB hizkuntza-kode ezezaguna: {tgt_nllb}")

    logger.info("[NLLB] %s -> %s | %d esaldi", src_nllb, tgt_nllb, len(to_translate))
    t0 = time.time()
    logger.info("[NLLB] forced_bos_token_id(%s) = %s", tgt_nllb, forced_bos)
    translations = []
    BATCH = 8
    for i in range(0, len(to_translate), BATCH):
        chunk = to_translate[i:i + BATCH]
        max_new = max(_adaptive_max_tokens(s) for s in chunk)
        # Setting src_lang and encoding must happen atomically (see lock note).
        with _NLLB_TOKENIZER_LOCK:
            nllb_tokenizer.src_lang = src_nllb
            inputs = nllb_tokenizer(
                chunk, return_tensors="pt", padding=True,
                truncation=True, max_length=512,
            )
        with torch.no_grad():
            outputs = nllb_model.generate(
                **inputs,
                forced_bos_token_id=forced_bos,  # start decoding in the target language
                max_new_tokens=max_new,
                num_beams=2,
                no_repeat_ngram_size=3,
                early_stopping=True,
            )
        decoded = nllb_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for src_s, out_s in zip(chunk, decoded):
            logger.info("[NLLB] %r -> %r", src_s[:60], out_s[:60])
        translations.extend(decoded)
    logger.info("[NLLB] Egina %.1fs-tan", time.time() - t0)
    return _rebuild(blocks, [t.strip() for t in translations])
def _hitz_translate(text: str, src: str, tgt: str) -> str:
    """Translate *text* with a HiTZ Marian model, one sentence at a time.

    Mirrors the usage from the HiTZ model README: a plain tokenizer call and
    generate() with its default settings, with no extra generation
    parameters. Whitespace-only input is returned unchanged.

    Raises:
        ValueError: if no model bundle is loaded for the (src, tgt) pair.
    """
    if not text.strip():
        return text

    bundle = hitz_models.get((src, tgt))
    if bundle is None:
        raise ValueError(f"HiTZ bikote ezezaguna: {src}->{tgt}")
    tokenizer, model = bundle["tokenizer"], bundle["model"]

    blocks = _flatten_to_sentences(text)
    sentences = [b for b in blocks if b is not None]
    if not sentences:
        return text

    logger.info("[HITZ] %s -> %s | %d esaldi", src, tgt, len(sentences))
    started = time.time()

    def _translate_one(sentence):
        # Official README usage: no additional generation parameters.
        encoded = tokenizer([sentence], return_tensors="pt", padding=True)
        logger.info("[HITZ] input_ids shape: %s", encoded["input_ids"].shape)
        with torch.no_grad():
            generated = model.generate(**encoded)
        output = tokenizer.decode(generated[0], skip_special_tokens=True).strip()
        logger.info("[HITZ] %r -> %r", sentence[:60], output[:60])
        return output

    translations = [_translate_one(sentence) for sentence in sentences]
    logger.info("[HITZ] Egina %.1fs-tan", time.time() - started)
    return _rebuild(blocks, translations)
def translate(text: str, src: str, tgt: str) -> str:
    """Translate *text* from ISO code *src* to ISO code *tgt*.

    When source and target are identical, *text* is returned unchanged and
    neither code is validated. Otherwise both codes must be keys of
    ISO_TO_NLLB, or an HTTP 400 is raised (source checked first).
    Every pair currently goes through NLLB. The HiTZ Marian path
    (_hitz_translate) is temporarily disabled because of compatibility
    problems.
    """
    if src == tgt:
        logger.info("[TRANSLATE] src==tgt (%s) -> aldaketarik gabe", src)
        return text

    for code in (src, tgt):
        if code not in ISO_TO_NLLB:
            raise HTTPException(status_code=400, detail=f"Hizkuntza ez da onartzen: {code}")

    # TEST: NLLB only (HiTZ temporarily disabled due to compatibility issues).
    logger.info("[TRANSLATE] NLLB zuzenean: %s -> %s", src, tgt)
    return _nllb_translate(text, ISO_TO_NLLB[src], ISO_TO_NLLB[tgt])
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models at startup and clear the registries at shutdown.

    Startup, in order:
    1. one EasyOCR reader per SCRIPTS entry (CPU, quantize=True);
    2. the NLLB-200 tokenizer and model, switched to eval mode.

    HiTZ Marian loading is skipped for now (transformers compatibility
    issues), so hitz_models stays empty. The Gemini key is only checked and
    logged. At shutdown the reader and HiTZ registries are cleared.
    """
    global nllb_model, nllb_tokenizer

    for script_name, lang_codes in SCRIPTS.items():
        logger.info("Reader kargatzen (quantize=True): %s %s", script_name, lang_codes)
        readers[script_name] = easyocr.Reader(lang_codes, gpu=False, quantize=True)

    logger.info("[LOAD] NLLB eredua kargatzen: %s", NLLB_MODEL_NAME)
    nllb_tokenizer = AutoTokenizer.from_pretrained(NLLB_MODEL_NAME)
    nllb_model = AutoModelForSeq2SeqLM.from_pretrained(NLLB_MODEL_NAME)
    nllb_model.eval()
    logger.info(
        "[LOAD] NLLB mota: %s | tokenizer: %s",
        nllb_model.__class__.__name__,
        nllb_tokenizer.__class__.__name__,
    )

    # HiTZ loading temporarily disabled (transformers compatibility issues).
    logger.info("[LOAD] HiTZ karga saltatzen (NLLB soilik modua)")

    if not GEMINI_API_KEY:
        logger.warning("[LOAD] Gemini API key gabe")
    else:
        logger.info("[LOAD] Gemini konfiguratuta: %s", GEMINI_MODEL)

    logger.info("[LOAD] Sistema prest.")
    yield

    readers.clear()
    hitz_models.clear()
# FastAPI application; lifespan() loads the OCR and translation models at
# startup and clears the registries at shutdown.
app = FastAPI(title="OCR + Itzulpena API", version="16.0.0", lifespan=lifespan)
# Permissive CORS: any origin may send GET/POST requests with any headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET"],
    allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Report service readiness.

    Returns the loaded OCR scripts, whether a Gemini key is configured,
    whether the NLLB model is loaded (and its class name), and the HiTZ
    pairs currently loaded, formatted as "src-tgt".
    """
    model = nllb_model
    loaded_pairs = [f"{src}-{tgt}" for src, tgt in hitz_models.keys()]
    return {
        "status": "ok",
        "scripts": [name for name in readers],
        "gemini": GEMINI_API_KEY != "",
        "nllb": model is not None,
        "nllb_class": model.__class__.__name__ if model else None,
        "hitz_pairs": loaded_pairs,
    }
# The EasyOCR readers are shared module-level objects. OCR used to run on the
# event loop itself, so at most one OCR pass ran at a time. Now that the
# work runs in a worker thread, this lock keeps that one-at-a-time guarantee
# (and bounds peak memory) while leaving the event loop free to serve other
# requests.
_OCR_LOCK = threading.Lock()


def _decode_rgb(data: bytes) -> Image.Image:
    """Decode uploaded image bytes into an RGB PIL image (CPU-bound)."""
    return Image.open(io.BytesIO(data)).convert("RGB")


def _run_ocr(pil_image: Image.Image, reader) -> str:
    """Resize, deskew and OCR *pil_image* with *reader*; return grouped text.

    Blocking and CPU-bound: call it through run_in_threadpool.
    """
    with _OCR_LOCK:
        img_array = _deskew(np.array(_resize(pil_image)))
        results = reader.readtext(img_array, detail=1, paragraph=False)
    return _group_into_lines(results)


@app.post("/predict")
async def predict(
    image: UploadFile = File(...),
    script: str = Form(default="latin"),
    correct: str = Form(default="true"),
):
    """Run OCR on an uploaded image and optionally post-correct it with Gemini.

    Form fields:
        image: the image file. Anything PIL cannot decode gives HTTP 400.
        script: key of ``readers`` (see SCRIPTS) selecting the EasyOCR
            reader. Unknown values give HTTP 400.
        correct: "true" (case-insensitive) to pass the OCR text through
            _gemini_correct; any other value returns the raw OCR text.

    Returns:
        JSON ``{"text": ...}``.

    Decoding and OCR are CPU-heavy, so they run in the thread pool rather
    than blocking the event loop (which would stall every other request).
    """
    if script not in readers:
        raise HTTPException(status_code=400, detail=f"Script ezezaguna: '{script}'.")
    contents = await image.read()
    try:
        pil_image = await run_in_threadpool(_decode_rgb, contents)
    except Exception:
        raise HTTPException(status_code=400, detail="Irudi baliogabea.") from None
    logger.info("[REQUEST] %s (%dx%d) script=%s correct=%s",
                image.filename, *pil_image.size, script, correct)
    t0 = time.time()
    raw_text = await run_in_threadpool(_run_ocr, pil_image, readers[script])
    logger.info("[OCR] Egina %.1fs-tan, %d karaktere", time.time() - t0, len(raw_text))
    if correct.lower() == "true":
        text = await _gemini_correct(raw_text)
    else:
        text = raw_text
    logger.info("[RESPONSE] %d karaktere (%.1fs guztira)", len(text), time.time() - t0)
    return JSONResponse(content={"text": text})
@app.post("/translate")
def translate_endpoint(
    text: str = Form(...),
    source_lang: str = Form(...),
    target_lang: str = Form(...),
):
    """Translate form-submitted *text* from *source_lang* to *target_lang*.

    Declared as a plain ``def``, so FastAPI runs it in its thread pool and
    the blocking model inference does not stall the event loop.
    HTTPExceptions from translate() (e.g. 400 for an unsupported language)
    propagate unchanged. Any other error is logged with its traceback and
    reported as HTTP 500.

    Returns:
        JSON ``{"translation": ...}``.
    """
    logger.info("[TRANSLATE] === %s -> %s (%d kar.) ===",
                source_lang, target_lang, len(text))
    t0 = time.time()
    try:
        translation = translate(text, source_lang, target_lang)
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback; the client only sees the message.
        logger.exception("[TRANSLATE] Errorea: %s", e)
        raise HTTPException(status_code=500, detail=f"Itzulpen-errorea: {e}") from e
    logger.info("[TRANSLATE] === Egina (%d kar., %.1fs) ===",
                len(translation), time.time() - t0)
    return JSONResponse(content={"translation": translation})