Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from hf_hub_ctranslate2 import TranslatorCT2fromHfHub | |
| from transformers import AutoTokenizer | |
| import time | |
| import re | |
| import logging | |
| from typing import Optional | |
# Basic logging and application setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Helsinki CTranslate2 Translation API", version="1.0.0")

# CORS is wide open (any origin, method, header) — acceptable for an internal
# demo service; tighten allow_origins before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# =============================================================================
# Model cache — each model is loaded once and reused
# =============================================================================
# Maps "src-tgt" keys to loaded translator objects; populated by get_model().
model_cache = {}

# Mapping of language pairs to CTranslate2 models on HuggingFace.
# NOTE(review): this table is never read anywhere in this file — get_model()
# builds repo names from f-strings instead. Either wire it in or drop it.
CT2_MODELS = {
    "en-ru": "Helsinki-NLP/opus-mt-en-ru",  # converted on the fly
    "ru-en": "Helsinki-NLP/opus-mt-ru-en",
    "en-uk": "Helsinki-NLP/opus-mt-en-uk",
    "en-de": "Helsinki-NLP/opus-mt-en-de",
    "en-fr": "Helsinki-NLP/opus-mt-en-ROMANCE",
    "en-zh": "Helsinki-NLP/opus-mt-en-zh",
}
def get_model(src_lang: str, tgt_lang: str):
    """Return a cached translator for the language pair, loading on first use.

    Tries a pre-converted CTranslate2 model from the `gaudi/` hub namespace
    first; if that fails for any reason, converts the original Helsinki-NLP
    model on the fly with `ct2-transformers-converter` into /tmp and wraps
    the raw ctranslate2.Translator in _CT2Wrapper so both paths expose the
    same .generate() interface.

    NOTE(review): model_cache access is unsynchronized — two concurrent first
    requests for the same pair may load the model twice; confirm the server
    runs single-worker or add a lock.
    """
    key = f"{src_lang}-{tgt_lang}"
    if key not in model_cache:
        logger.info(f"⏳ Загружаем модель: {key}")
        # Candidate repo names: ready-made CT2 model vs. original Helsinki.
        ct2_model_name = f"gaudi/opus-mt-{src_lang}-{tgt_lang}-ctranslate2"
        hf_model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
        try:
            # First try the pre-converted CT2 model from gaudi.
            tokenizer = AutoTokenizer.from_pretrained(ct2_model_name)
            model = TranslatorCT2fromHfHub(
                model_name_or_path=ct2_model_name,
                device="cpu",
                compute_type="int8",  # INT8 — maximum speed on CPU
                tokenizer=tokenizer
            )
            logger.info(f"✅ Загружена CTranslate2 модель: {ct2_model_name}")
        except Exception as e:
            logger.warning(f"CT2 модель не найдена ({e}), конвертируем из Helsinki...")
            # Convert on the fly from the original Helsinki model.
            import ctranslate2
            import subprocess, os
            converted_path = f"/tmp/ct2_{src_lang}_{tgt_lang}"
            # Skip conversion if a previous run already produced the directory.
            if not os.path.exists(converted_path):
                subprocess.run([
                    "ct2-transformers-converter",
                    "--model", hf_model_name,
                    "--output_dir", converted_path,
                    "--quantization", "int8",
                    "--force"
                ], check=True)
            tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
            model = ctranslate2.Translator(
                converted_path,
                device="cpu",
                compute_type="int8",
                inter_threads=2,
                intra_threads=4,
            )
            # Wrap in a small adapter so callers get a uniform interface.
            model = _CT2Wrapper(model, tokenizer)
            logger.info(f"✅ Сконвертирована и загружена: {hf_model_name}")
        model_cache[key] = model
    return model_cache[key]
| class _CT2Wrapper: | |
| """Обёртка над ctranslate2.Translator для единого интерфейса""" | |
| def __init__(self, translator, tokenizer): | |
| self.translator = translator | |
| self.tokenizer = tokenizer | |
| def generate(self, text: list[str]) -> list[str]: | |
| results = [] | |
| for t in text: | |
| tokens = self.tokenizer.convert_ids_to_tokens( | |
| self.tokenizer.encode(t) | |
| ) | |
| out = self.translator.translate_batch([tokens]) | |
| decoded = self.tokenizer.decode( | |
| self.tokenizer.convert_tokens_to_ids(out[0].hypotheses[0]), | |
| skip_special_tokens=True | |
| ) | |
| results.append(decoded) | |
| return results | |
def split_text(text: str, max_chars: int = 400) -> list[str]:
    """Split long text into chunks of roughly max_chars, on sentence boundaries.

    Sentences are delimited by whitespace following '.', '!' or '?'. A single
    sentence longer than max_chars becomes its own (oversized) chunk. Always
    returns at least one element.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks: list[str] = []
    buf = ""
    for sentence in sentences:
        if len(buf) + len(sentence) < max_chars:
            buf += sentence + " "
            continue
        # Flush the accumulated chunk and start a new one.
        if buf.strip():
            chunks.append(buf.strip())
        buf = sentence + " "
    if buf.strip():
        chunks.append(buf.strip())
    return chunks if chunks else [text]
def do_translate(text: str, src_lang: str, tgt_lang: str) -> str:
    """Translate *text* from src_lang to tgt_lang, chunking long input.

    Whitespace-only input is returned unchanged without loading any model.
    Chunks are rejoined with single spaces.
    """
    if not text.strip():
        return text
    translator = get_model(src_lang, tgt_lang)
    pieces = split_text(text)
    # CTranslate2 is very efficient on batches — translate all chunks at once.
    return " ".join(translator.generate(pieces))
# =============================================================================
# ENDPOINTS
# =============================================================================
def root():
    """Landing info for the service.

    NOTE(review): no @app.get("/") decorator is visible here — it was likely
    lost in extraction; confirm route registration.
    """
    info = {
        "message": "Helsinki CTranslate2 API — 6-10x быстрее обычного Helsinki",
        "engine": "CTranslate2 INT8",
        "docs": "/docs"
    }
    return info
def health():
    """Health probe reporting engine info and currently loaded model pairs.

    NOTE(review): no @app route decorator visible — likely lost in extraction.
    """
    loaded = list(model_cache.keys())
    return {
        "status": "ok",
        "engine": "CTranslate2",
        "compute_type": "int8",
        "loaded_models": loaded
    }
def list_models():
    """OpenAI-style model listing (single static entry).

    NOTE(review): no @app route decorator visible — likely lost in extraction.
    """
    model_entry = {"id": "helsinki-en-ru", "object": "model"}
    return {"object": "list", "data": [model_entry]}
class TranslateRequest(BaseModel):
    # Request body for the translate endpoint.
    text: str
    src_lang: str = "en"  # source language code, e.g. "en"
    tgt_lang: str = "ru"  # target language code, e.g. "ru"
def translate(req: TranslateRequest):
    """Translate a single text, returning the result plus timing metadata.

    Raises HTTP 400 on whitespace-only input.
    NOTE(review): no @app route decorator visible — likely lost in extraction.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="Пустой текст")
    started = time.time()
    translated = do_translate(req.text, req.src_lang, req.tgt_lang)
    return {
        "translated_text": translated,
        "src_lang": req.src_lang,
        "tgt_lang": req.tgt_lang,
        "elapsed_sec": round(time.time() - started, 2),
        "engine": "ctranslate2-int8"
    }
def translate_batch(
    texts: list[str],
    src_lang: str = "en",
    tgt_lang: str = "ru"
):
    """Translate up to 200 texts in a single batched model call.

    Raises HTTP 400 on an empty list or on more than 200 items.
    NOTE(review): no @app route decorator visible — likely lost in extraction.
    """
    if not texts:
        raise HTTPException(status_code=400, detail="Пустой список")
    if len(texts) > 200:
        raise HTTPException(status_code=400, detail="Максимум 200 текстов")
    started = time.time()
    translations = get_model(src_lang, tgt_lang).generate(texts)
    return {
        "translations": translations,
        "count": len(translations),
        "elapsed_sec": round(time.time() - started, 2),
        "engine": "ctranslate2-int8"
    }
# =============================================================================
# OpenAI-compatible endpoint for pdf2zh
# =============================================================================
async def openai_chat(request: Request):
    """OpenAI chat-completions-shaped translation endpoint.

    Derives the language pair from the request's model name (e.g.
    "helsinki-en-ru" -> en→ru, defaulting to en→ru when the last two
    dash-separated parts are not 2-letter codes), extracts the user text —
    optionally the span between "Source text:" and "Translated text:" markers
    as sent by pdf2zh — translates it, and returns a chat.completion-shaped
    payload. Any failure is reported as HTTP 500.

    NOTE(review): no @app.post route decorator is visible here — probably
    lost in extraction; confirm route registration.
    """
    try:
        body = await request.json()
        messages = body.get("messages", [])
        model_name = body.get("model", "helsinki-en-ru")
        # Parse the pair from the model id suffix; fall back to en/ru.
        parts = model_name.split("-")
        src_lang = parts[-2] if len(parts) >= 3 and len(parts[-2]) == 2 else "en"
        tgt_lang = parts[-1] if len(parts) >= 3 and len(parts[-1]) == 2 else "ru"
        # First user message's content is the prompt; empty string if none.
        raw_content = next(
            (m.get("content", "") for m in messages if m.get("role") == "user"), ""
        )
        # Extract the text between "Source text:" and "Translated text:".
        match = re.search(
            r'(?:Source text)[:\s]*\n?(.*?)(?:\n?(?:Translated text)[:\s]*)',
            raw_content, re.IGNORECASE | re.DOTALL
        )
        text_to_translate = match.group(1).strip() if match else raw_content
        logger.info(f"🔤 [{src_lang}→{tgt_lang}]: {text_to_translate[:100]}")
        translated = do_translate(text_to_translate, src_lang, tgt_lang)
        logger.info(f"✅ {translated[:100]}")
        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": translated},
                "finish_reason": "stop"
            }],
            # Token counts are approximated by whitespace word counts.
            "usage": {
                "prompt_tokens": len(text_to_translate.split()),
                "completion_tokens": len(translated.split()),
                "total_tokens": len(text_to_translate.split()) + len(translated.split())
            }
        }
    except Exception as e:
        logger.error(f"❌ {e}")
        raise HTTPException(status_code=500, detail=str(e))