# ctranslate / app.py
# Hugging Face Space upload header (anatoli72, "Upload 4 files", commit
# 0f6afd4 verified) — kept as a comment so the file parses as Python.
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from hf_hub_ctranslate2 import TranslatorCT2fromHfHub
from transformers import AutoTokenizer
import time
import re
import logging
from typing import Optional
# Module-level logging: INFO level, named logger for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Helsinki CTranslate2 Translation API", version="1.0.0")

# CORS is wide open (any origin/method/header) — acceptable for a public demo
# Space; tighten allow_origins before exposing anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# =============================================================================
# Model cache — each language pair is loaded once and reused across requests.
# =============================================================================
model_cache = {}

# Mapping of language pairs to Helsinki-NLP models on HuggingFace.
# Pre-converted CTranslate2 models from the gaudi/ namespace are preferred.
# NOTE(review): this table is not consulted by get_model(), which builds repo
# names from the language codes itself — confirm whether it is meant to drive
# model selection or is dead configuration.
CT2_MODELS = {
    "en-ru": "Helsinki-NLP/opus-mt-en-ru",  # converted on the fly if needed
    "ru-en": "Helsinki-NLP/opus-mt-ru-en",
    "en-uk": "Helsinki-NLP/opus-mt-en-uk",
    "en-de": "Helsinki-NLP/opus-mt-en-de",
    "en-fr": "Helsinki-NLP/opus-mt-en-ROMANCE",
    "en-zh": "Helsinki-NLP/opus-mt-en-zh",
}
def get_model(src_lang: str, tgt_lang: str):
    """Return a cached translator for the ``src->tgt`` pair, loading it on first use.

    Tries a pre-converted CTranslate2 model from the ``gaudi`` namespace first;
    if that fails, converts the original Helsinki-NLP model on the fly and wraps
    the raw ``ctranslate2.Translator`` so both paths expose ``generate()``.

    Raises:
        subprocess.CalledProcessError: if the on-the-fly conversion fails.
    """
    key = f"{src_lang}-{tgt_lang}"
    if key in model_cache:
        return model_cache[key]

    logger.info(f"⏳ Загружаем модель: {key}")
    ct2_model_name = f"gaudi/opus-mt-{src_lang}-{tgt_lang}-ctranslate2"
    hf_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
    try:
        # Preferred path: a ready-made CTranslate2 model from gaudi/.
        tokenizer = AutoTokenizer.from_pretrained(ct2_model_name)
        model = TranslatorCT2fromHfHub(
            model_name_or_path=ct2_model_name,
            device="cpu",
            compute_type="int8",  # INT8 — fastest on CPU
            tokenizer=tokenizer,
        )
        logger.info(f"✅ Загружена CTranslate2 модель: {ct2_model_name}")
    except Exception as e:
        logger.warning(f"CT2 модель не найдена ({e}), конвертируем из Helsinki...")
        # Fallback: convert the original Helsinki model on the fly.
        import ctranslate2
        import os
        import shutil
        import subprocess

        converted_path = f"/tmp/ct2_{src_lang}_{tgt_lang}"
        if not os.path.exists(converted_path):
            try:
                subprocess.run([
                    "ct2-transformers-converter",
                    "--model", hf_model_name,
                    "--output_dir", converted_path,
                    "--quantization", "int8",
                    "--force"
                ], check=True)
            except BaseException:
                # Fix: a failed conversion must not leave a partial output
                # directory behind — otherwise the next call would skip the
                # conversion step and try to load a broken model.
                shutil.rmtree(converted_path, ignore_errors=True)
                raise
        tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
        translator = ctranslate2.Translator(
            converted_path,
            device="cpu",
            compute_type="int8",
            inter_threads=2,
            intra_threads=4,
        )
        # Wrap the raw translator so it matches TranslatorCT2fromHfHub's
        # generate() interface used by the rest of the app.
        model = _CT2Wrapper(translator, tokenizer)
        logger.info(f"✅ Сконвертирована и загружена: {hf_model_name}")
    model_cache[key] = model
    return model_cache[key]
class _CT2Wrapper:
"""Обёртка над ctranslate2.Translator для единого интерфейса"""
def __init__(self, translator, tokenizer):
self.translator = translator
self.tokenizer = tokenizer
def generate(self, text: list[str]) -> list[str]:
results = []
for t in text:
tokens = self.tokenizer.convert_ids_to_tokens(
self.tokenizer.encode(t)
)
out = self.translator.translate_batch([tokens])
decoded = self.tokenizer.decode(
self.tokenizer.convert_tokens_to_ids(out[0].hypotheses[0]),
skip_special_tokens=True
)
results.append(decoded)
return results
def split_text(text: str, max_chars: int = 400) -> list[str]:
    """Split long text into chunks of at most ~max_chars at sentence boundaries.

    Sentences are detected by `.`, `!` or `?` followed by whitespace. A single
    sentence longer than max_chars still becomes its own chunk. Always returns
    at least one element.
    """
    parts: list[str] = []
    buffer = ""
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        # Keep accumulating while the running buffer stays under the limit.
        if len(buffer) + len(sentence) < max_chars:
            buffer += sentence + " "
            continue
        if buffer.strip():
            parts.append(buffer.strip())
        buffer = sentence + " "
    if buffer.strip():
        parts.append(buffer.strip())
    return parts if parts else [text]
def do_translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate *text* from src_lang to tgt_lang, chunking long input first.

    Whitespace-only input is returned unchanged without loading a model.
    """
    if not text.strip():
        return text
    translator = get_model(src_lang, tgt_lang)
    pieces = split_text(text)
    # CTranslate2 handles batches efficiently, so all chunks go in one call.
    return " ".join(translator.generate(pieces))
# =============================================================================
# ENDPOINTS
# =============================================================================
@app.get("/")
def root():
    """Landing endpoint: short service description and a pointer to /docs."""
    return dict(
        message="Helsinki CTranslate2 API — 6-10x быстрее обычного Helsinki",
        engine="CTranslate2 INT8",
        docs="/docs",
    )
@app.get("/health")
def health():
    """Liveness probe; also reports which language pairs are already cached."""
    loaded = list(model_cache.keys())
    return {
        "status": "ok",
        "engine": "CTranslate2",
        "compute_type": "int8",
        "loaded_models": loaded,
    }
@app.get("/v1/models")
def list_models():
    """Minimal OpenAI-style model listing (a single static entry)."""
    entry = {"id": "helsinki-en-ru", "object": "model"}
    return {"object": "list", "data": [entry]}
class TranslateRequest(BaseModel):
    """Request body for POST /translate."""
    text: str  # source text; must be non-empty after stripping whitespace
    src_lang: str = "en"  # source language code (used to pick the model)
    tgt_lang: str = "ru"  # target language code (used to pick the model)
@app.post("/translate")
def translate(req: TranslateRequest):
    """Translate a single text and report the elapsed wall-clock time.

    Raises 400 for whitespace-only input.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="Пустой текст")
    started = time.time()
    result = do_translate(req.text, req.src_lang, req.tgt_lang)
    return {
        "translated_text": result,
        "src_lang": req.src_lang,
        "tgt_lang": req.tgt_lang,
        "elapsed_sec": round(time.time() - started, 2),
        "engine": "ctranslate2-int8",
    }
@app.post("/translate/batch")
def translate_batch(
    texts: list[str],
    src_lang: str = "en",
    tgt_lang: str = "ru"
):
    """Translate up to 200 texts in a single CTranslate2 batch.

    Raises 400 for an empty list or for more than 200 texts.
    """
    if not texts:
        raise HTTPException(status_code=400, detail="Пустой список")
    if len(texts) > 200:
        raise HTTPException(status_code=400, detail="Максимум 200 текстов")
    started = time.time()
    translations = get_model(src_lang, tgt_lang).generate(texts)
    elapsed = round(time.time() - started, 2)
    return {
        "translations": translations,
        "count": len(translations),
        "elapsed_sec": elapsed,
        "engine": "ctranslate2-int8",
    }
# =============================================================================
# OpenAI-совместимый endpoint для pdf2zh
# =============================================================================
@app.post("/v1/chat/completions")
async def openai_chat(request: Request):
    """OpenAI-compatible chat endpoint (used by pdf2zh as a translation backend).

    Parses the language pair out of the model name (e.g. "helsinki-en-ru"),
    extracts the source text from the user prompt, translates it, and returns
    an OpenAI chat-completion-shaped response.

    Raises:
        HTTPException(500): on any failure, with the original cause chained.
    """
    try:
        body = await request.json()
        messages = body.get("messages", [])
        model_name = body.get("model", "helsinki-en-ru")
        # Model name encodes the pair as "...-<src>-<tgt>" (two-letter codes);
        # anything else falls back to en -> ru.
        parts = model_name.split("-")
        src_lang = parts[-2] if len(parts) >= 3 and len(parts[-2]) == 2 else "en"
        tgt_lang = parts[-1] if len(parts) >= 3 and len(parts[-1]) == 2 else "ru"
        # First user message carries the prompt; default to empty string.
        raw_content = next(
            (m.get("content", "") for m in messages if m.get("role") == "user"), ""
        )
        # pdf2zh wraps the payload as "Source text: ... Translated text:" —
        # extract just the source span, or translate the whole prompt if the
        # wrapper is absent.
        match = re.search(
            r'(?:Source text)[:\s]*\n?(.*?)(?:\n?(?:Translated text)[:\s]*)',
            raw_content, re.IGNORECASE | re.DOTALL
        )
        text_to_translate = match.group(1).strip() if match else raw_content
        logger.info(f"🔤 [{src_lang}{tgt_lang}]: {text_to_translate[:100]}")
        translated = do_translate(text_to_translate, src_lang, tgt_lang)
        logger.info(f"✅ {translated[:100]}")
        word_in = len(text_to_translate.split())
        word_out = len(translated.split())
        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": translated},
                "finish_reason": "stop"
            }],
            # Token counts are approximated by whitespace word counts.
            "usage": {
                "prompt_tokens": word_in,
                "completion_tokens": word_out,
                "total_tokens": word_in + word_out
            }
        }
    except Exception as e:
        # Fix: logger.exception records the full traceback (logger.error did
        # not), and "from e" preserves the cause chain for debugging.
        logger.exception(f"❌ {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e