from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from hf_hub_ctranslate2 import TranslatorCT2fromHfHub
from transformers import AutoTokenizer
import os
import re
import subprocess
import time
import logging
from typing import Optional

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Helsinki CTranslate2 Translation API", version="1.0.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# =============================================================================
# Model cache — each language pair is loaded exactly once per process
# =============================================================================
model_cache = {}

# Language-pair → HuggingFace model mapping.
# NOTE(review): this table is not read by get_model() below (it builds repo
# names from the lang codes directly) — kept for reference / external callers.
CT2_MODELS = {
    "en-ru": "Helsinki-NLP/opus-mt-en-ru",  # converted on the fly if needed
    "ru-en": "Helsinki-NLP/opus-mt-ru-en",
    "en-uk": "Helsinki-NLP/opus-mt-en-uk",
    "en-de": "Helsinki-NLP/opus-mt-en-de",
    "en-fr": "Helsinki-NLP/opus-mt-en-ROMANCE",
    "en-zh": "Helsinki-NLP/opus-mt-en-zh",
}


def get_model(src_lang: str, tgt_lang: str):
    """Return a translator for ``src_lang -> tgt_lang``, loading it on first use.

    Strategy: first try a pre-converted CTranslate2 model from the ``gaudi/``
    namespace on the Hub; if that fails, convert the original Helsinki model
    to CTranslate2 INT8 under /tmp and wrap it so both paths expose the same
    ``generate(list[str]) -> list[str]`` interface.

    Raises:
        subprocess.CalledProcessError: if the on-the-fly conversion fails.
    """
    key = f"{src_lang}-{tgt_lang}"
    if key in model_cache:
        return model_cache[key]

    logger.info(f"⏳ Загружаем модель: {key}")
    ct2_model_name = f"gaudi/opus-mt-{src_lang}-{tgt_lang}-ctranslate2"
    hf_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    try:
        # Preferred path: a ready-made CTranslate2 model from gaudi/.
        tokenizer = AutoTokenizer.from_pretrained(ct2_model_name)
        model = TranslatorCT2fromHfHub(
            model_name_or_path=ct2_model_name,
            device="cpu",
            compute_type="int8",  # INT8 — fastest compute type on CPU
            tokenizer=tokenizer,
        )
        logger.info(f"✅ Загружена CTranslate2 модель: {ct2_model_name}")
    except Exception as e:
        logger.warning(f"CT2 модель не найдена ({e}), конвертируем из Helsinki...")
        # Fallback: convert the original Helsinki model on the fly.
        import ctranslate2  # local import: only needed on this path

        converted_path = f"/tmp/ct2_{src_lang}_{tgt_lang}"
        if not os.path.exists(converted_path):
            subprocess.run(
                [
                    "ct2-transformers-converter",
                    "--model", hf_model_name,
                    "--output_dir", converted_path,
                    "--quantization", "int8",
                    "--force",
                ],
                check=True,
            )
        tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
        translator = ctranslate2.Translator(
            converted_path,
            device="cpu",
            compute_type="int8",
            inter_threads=2,
            intra_threads=4,
        )
        # Wrap in a small adapter so both load paths share one interface.
        model = _CT2Wrapper(translator, tokenizer)
        logger.info(f"✅ Сконвертирована и загружена: {hf_model_name}")

    model_cache[key] = model
    return model_cache[key]


class _CT2Wrapper:
    """Adapter over ctranslate2.Translator exposing ``generate(list[str])``."""

    def __init__(self, translator, tokenizer):
        self.translator = translator
        self.tokenizer = tokenizer

    def generate(self, text: list[str]) -> list[str]:
        """Translate all texts in a single CTranslate2 batch call.

        Tokenizes every input up front and issues one ``translate_batch``
        call instead of one call per text — CTranslate2 is far more
        efficient on batches, and per-text results are identical.
        """
        if not text:
            return []
        token_batches = [
            self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(t))
            for t in text
        ]
        outputs = self.translator.translate_batch(token_batches)
        return [
            self.tokenizer.decode(
                self.tokenizer.convert_tokens_to_ids(out.hypotheses[0]),
                skip_special_tokens=True,
            )
            for out in outputs
        ]


def split_text(text: str, max_chars: int = 400) -> list[str]:
    """Split long text into chunks of at most ~max_chars on sentence boundaries.

    Sentences are detected by ``.!?`` followed by whitespace; a single
    sentence longer than max_chars becomes its own chunk. Always returns at
    least one chunk.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks, current = [], ""
    for s in sentences:
        if len(current) + len(s) < max_chars:
            current += s + " "
        else:
            if current.strip():
                chunks.append(current.strip())
            current = s + " "
    if current.strip():
        chunks.append(current.strip())
    return chunks or [text]


def do_translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate ``text`` chunk-by-chunk and re-join with single spaces.

    Blank/whitespace-only input is returned unchanged without loading a model.
    """
    if not text.strip():
        return text
    model = get_model(src_lang, tgt_lang)
    chunks = split_text(text)
    # Batch translation — CTranslate2 is very efficient on batches.
    translated = model.generate(chunks)
    return " ".join(translated)


# =============================================================================
# ENDPOINTS
# =============================================================================


@app.get("/")
def root():
    """Human-readable service banner with a pointer to the interactive docs."""
    return {
        "message": "Helsinki CTranslate2 API — 6-10x быстрее обычного Helsinki",
        "engine": "CTranslate2 INT8",
        "docs": "/docs",
    }


@app.get("/health")
def health():
    """Liveness probe; also reports which language pairs are already loaded."""
    return {
        "status": "ok",
        "engine": "CTranslate2",
        "compute_type": "int8",
        "loaded_models": list(model_cache.keys()),
    }


@app.get("/v1/models")
def list_models():
    """Minimal OpenAI-style model listing for clients that probe /v1/models."""
    return {
        "object": "list",
        "data": [{"id": "helsinki-en-ru", "object": "model"}],
    }


class TranslateRequest(BaseModel):
    # Text to translate plus ISO-639-1 source/target language codes.
    text: str
    src_lang: str = "en"
    tgt_lang: str = "ru"


@app.post("/translate")
def translate(req: TranslateRequest):
    """Translate one text; long input is chunked by sentences internally.

    Raises:
        HTTPException(400): if the text is empty/whitespace-only.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="Пустой текст")
    t0 = time.time()
    result = do_translate(req.text, req.src_lang, req.tgt_lang)
    elapsed = round(time.time() - t0, 2)
    return {
        "translated_text": result,
        "src_lang": req.src_lang,
        "tgt_lang": req.tgt_lang,
        "elapsed_sec": elapsed,
        "engine": "ctranslate2-int8",
    }


@app.post("/translate/batch")
def translate_batch(
    texts: list[str],
    src_lang: str = "en",
    tgt_lang: str = "ru",
):
    """Translate up to 200 texts in a single CTranslate2 batch.

    Raises:
        HTTPException(400): on an empty list or more than 200 texts.
    """
    if not texts:
        raise HTTPException(status_code=400, detail="Пустой список")
    if len(texts) > 200:
        raise HTTPException(status_code=400, detail="Максимум 200 текстов")
    t0 = time.time()
    model = get_model(src_lang, tgt_lang)
    results = model.generate(texts)
    elapsed = round(time.time() - t0, 2)
    return {
        "translations": results,
        "count": len(results),
        "elapsed_sec": elapsed,
        "engine": "ctranslate2-int8",
    }


# =============================================================================
# OpenAI-compatible endpoint for pdf2zh
# =============================================================================
@app.post("/v1/chat/completions")
async def openai_chat(request: Request):
    """OpenAI chat-completions shim that translates the user message.

    The language pair is parsed from the model name ("helsinki-en-ru" →
    en→ru, defaulting to en→ru). When the prompt follows pdf2zh's
    "Source text: ... Translated text:" template, only the source span is
    translated; otherwise the whole user message is.

    Raises:
        HTTPException(500): wrapping any unexpected error.
    """
    try:
        body = await request.json()
        messages = body.get("messages", [])
        model_name = body.get("model", "helsinki-en-ru")

        # "helsinki-en-ru" -> ("en", "ru"); fall back to en→ru for odd names.
        parts = model_name.split("-")
        src_lang = parts[-2] if len(parts) >= 3 and len(parts[-2]) == 2 else "en"
        tgt_lang = parts[-1] if len(parts) >= 3 and len(parts[-1]) == 2 else "ru"

        raw_content = next(
            (m.get("content", "") for m in messages if m.get("role") == "user"),
            "",
        )
        # Extract the payload between "Source text:" and "Translated text:".
        match = re.search(
            r'(?:Source text)[:\s]*\n?(.*?)(?:\n?(?:Translated text)[:\s]*)',
            raw_content,
            re.IGNORECASE | re.DOTALL,
        )
        text_to_translate = match.group(1).strip() if match else raw_content

        logger.info(f"🔤 [{src_lang}→{tgt_lang}]: {text_to_translate[:100]}")
        translated = do_translate(text_to_translate, src_lang, tgt_lang)
        logger.info(f"✅ {translated[:100]}")

        # Whitespace word counts are a cheap stand-in for real token counts.
        prompt_tokens = len(text_to_translate.split())
        completion_tokens = len(translated.split())
        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": translated},
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as opaque 500s.
        raise
    except Exception as e:
        logger.error(f"❌ {e}")
        raise HTTPException(status_code=500, detail=str(e))