# Extraction residue from the hosting page ("Spaces: Running / Running") —
# kept as a comment; not part of the program.
import json
import logging
import os
import sys
import time
import traceback
from contextlib import asynccontextmanager
from typing import Any, Dict

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# ======= LOGGING =======
# Everything goes to stderr so stdout stays clean for the serving runtime.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stderr)],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("eroha-api")
# ======= OPTIONAL ML IMPORTS =======
# Degrade gracefully when the ML stack is missing: `pipeline` becomes None
# (get_model raises later) and detect() falls back to always-English.
try:
    from transformers import pipeline
    from langdetect import detect
except Exception as e:
    logger.error("[ImportError] transformers/langdetect not available: %s", e, exc_info=True)
    pipeline = None

    def detect(text):
        """Fallback language detector used when langdetect is unavailable."""
        return "en"
# ======= SETTINGS =======
# Point the Hugging Face cache at a writable tmp directory (container
# home dirs are often read-only on hosted Spaces).
# NOTE(review): transformers is imported above this point; if it reads
# HF_HOME at import time, setting the env var here may be too late — verify.
HF_HOME = "/tmp/huggingface"
os.makedirs(HF_HOME, exist_ok=True)
os.environ["HF_HOME"] = HF_HOME

# ======= MODELS =======
# Lazily populated cache: language code -> loaded summarization pipeline.
_model_cache: Dict[str, Any] = {}
def get_model(lang: str):
    """Return a cached summarization pipeline for *lang*, loading on first use.

    Unknown language codes fall back to the English model. Raises
    RuntimeError when the transformers stack failed to import.
    """
    if pipeline is None:
        raise RuntimeError("Transformers pipeline is not available")
    try:
        return _model_cache[lang]
    except KeyError:
        pass
    # Language -> checkpoint; anything unmapped gets the English model.
    checkpoint = {
        "ru": "IlyaGusev/mbart_ru_sum_gazeta",
        "en": "facebook/bart-large-cnn",
    }.get(lang, "facebook/bart-large-cnn")
    summarizer = pipeline("summarization", model=checkpoint, device=-1)  # device=-1 -> CPU
    _model_cache[lang] = summarizer
    return summarizer
# ======= FastAPI =======
@asynccontextmanager
async def lifespan(app: "FastAPI"):
    """Application lifespan: pre-load both summarization models at startup.

    Fix: FastAPI's ``lifespan=`` parameter expects an async context-manager
    factory; the original bare async generator function was never usable as
    one. ``asynccontextmanager`` was already imported but unused — applying
    it is the intended wiring. The previously dead ``start`` local is now
    used to log warmup duration.
    """
    start = time.time()
    logger.info("[Startup] warming up models...")
    for lang in ("en", "ru"):
        try:
            get_model(lang)
        except Exception as e:
            # Warmup is best-effort: a failed preload must not block startup.
            logger.error("Warmup failed: %s", e)
    logger.info("[Startup] warmup finished in %.1fs", time.time() - start)
    yield
    logger.info("[Shutdown] done")
# Application object; the version string is duplicated in the root endpoint.
app = FastAPI(title="Eroha Agent API", version="v3.5", lifespan=lifespan)
# NOTE(review): CORS is wide open (any origin/method/header) — acceptable
# for a demo Space, tighten before production use.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# ======= REQUEST MODELS =======
class SummarizeRequest(BaseModel):
    """Body for the summarize endpoint: raw text to condense (3..1M chars)."""
    text: str = Field(..., min_length=3, max_length=1_000_000)
class MemoryRequest(BaseModel):
    """Body for the memory endpoints.

    ``key`` is substring-matched case-insensitively on retrieval;
    ``content`` is the payload stored alongside it.
    """
    key: str
    content: str
# ======= ENDPOINTS =======
# NOTE(review): no route decorators (e.g. @app.get("/")) are visible in this
# file — if they were not lost in extraction, these handlers are never
# registered on `app`. Confirm against the original source.
async def root():
    """Liveness endpoint; mirrors the application version string."""
    payload = {"status": "ok", "version": "v3.5"}
    return payload
async def ping():
    """Health check: reports which language models are already loaded."""
    loaded_langs = list(_model_cache.keys())
    return {"status": "healthy", "cache": loaded_langs}
async def summarize(req: SummarizeRequest):
    """Summarize the request text with a Russian or English model.

    Language choice is a crude heuristic: any Cyrillic 'а' anywhere in
    the text selects the Russian model. NOTE(review): langdetect's
    detect() is imported above and looks like the intended detector —
    confirm before swapping it in (it changes routing for mixed text).
    """
    # The probe character is U+0430 CYRILLIC SMALL LETTER A, not Latin 'a'.
    lang = "ru" if "а" in req.text.lower() else "en"
    try:
        model = get_model(lang)
    except RuntimeError as e:
        # ML stack unavailable: return an explicit 503 instead of letting
        # the RuntimeError surface as an unhandled 500. (HTTPException was
        # imported but previously unused.)
        raise HTTPException(status_code=503, detail=str(e)) from e
    # Input is truncated to 2000 chars to bound model latency and memory.
    result = model(req.text[:2000], max_length=180, min_length=50, do_sample=False)
    return {"summary": result[0]["summary_text"].strip(), "lang": lang}
| # ======= MEMORY API ======= | |
| async def memorize(req: MemoryRequest): | |
| with open("memory.json", "a") as f: | |
| f.write(json.dumps(req.dict(), ensure_ascii=False) + "\\n") | |
| return {"status": "saved"} | |
| async def retrieve(req: MemoryRequest): | |
| if not os.path.exists("memory.json"): | |
| return {"found": []} | |
| with open("memory.json", "r") as f: | |
| lines = [json.loads(l) for l in f] | |
| found = [l for l in lines if req.key.lower() in l["key"].lower()] | |
| return {"found": found} | |