# eroha-agentapi / app.py
# Source: Hugging Face Space (Yermek68), revision ac8fce8 — page chrome
# ("raw / history / blame", file size) removed so the module is runnable.
import os
import sys
import time
import json
import logging
import traceback
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from typing import Dict, Any
# ======= LOGGING =======
# Everything goes to stderr so Space logs capture it.
_stderr_handler = logging.StreamHandler(sys.stderr)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[_stderr_handler],
)
logger = logging.getLogger("eroha-api")
# ======= IMPORTS =======
# ML dependencies are optional: when they are missing the app still boots,
# with `pipeline` set to None and a trivial detect() that always says "en".
try:
    from transformers import pipeline
    from langdetect import detect
except Exception as import_err:
    logging.getLogger("eroha-api").error(
        "[ImportError] transformers/langdetect not available: %s",
        import_err,
        exc_info=True,
    )
    pipeline = None

    def detect(text):
        # Fallback: without langdetect, treat every input as English.
        return "en"
# ======= SETTINGS =======
# Keep the Hugging Face cache under /tmp — presumably the only writable
# location in the Space container (TODO confirm).
HF_HOME = "/tmp/huggingface"
os.makedirs(HF_HOME, exist_ok=True)
os.environ["HF_HOME"] = HF_HOME
# ======= MODELS =======
# Lazily-built cache: language code -> loaded summarization pipeline.
_model_cache: Dict[str, Any] = {}

# Known language -> checkpoint; anything else falls back to the English model.
_CHECKPOINT_BY_LANG = {
    "ru": "IlyaGusev/mbart_ru_sum_gazeta",
    "en": "facebook/bart-large-cnn",
}


def get_model(lang: str):
    """Return a cached summarization pipeline for *lang*, loading on first use.

    Unknown languages fall back to the English checkpoint. Runs on CPU
    (device=-1). Raises RuntimeError if transformers failed to import.
    """
    if pipeline is None:
        raise RuntimeError("Transformers pipeline is not available")
    cached = _model_cache.get(lang)
    if cached is not None:
        return cached
    checkpoint = _CHECKPOINT_BY_LANG.get(lang, "facebook/bart-large-cnn")
    summarizer = pipeline("summarization", model=checkpoint, device=-1)
    _model_cache[lang] = summarizer
    return summarizer
# ======= FastAPI =======
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm the model cache on startup; log once on shutdown.

    Warmup failures are logged but never abort startup — the endpoints
    will retry model loading on demand.
    """
    warmup_started = time.time()
    logger.info("[Startup] warming up models...")
    for language in ("en", "ru"):
        try:
            get_model(language)
        except Exception as exc:
            logger.error("Warmup failed: %s", exc)
    yield
    logger.info("[Shutdown] done")
# Application object plus a permissive CORS policy (any origin/method/header).
app = FastAPI(
    title="Eroha Agent API",
    version="v3.5",
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ======= REQUEST MODELS =======
class SummarizeRequest(BaseModel):
    # Text to summarize; length bounded (3..1,000,000 chars) so pydantic
    # rejects empty/oversized payloads before they reach the model.
    text: str = Field(..., min_length=3, max_length=1_000_000)
class MemoryRequest(BaseModel):
    # Lookup key; /retrieve matches it as a case-insensitive substring.
    key: str
    # Payload to store via /memorize. NOTE(review): /retrieve also requires
    # this field but never reads it — a separate retrieve model would be cleaner.
    content: str
# ======= ENDPOINTS =======
@app.get("/")
async def root():
    """Liveness probe reporting the API version."""
    payload = {"status": "ok", "version": "v3.5"}
    return payload
@app.get("/ping")
async def ping():
    """Health check; reports which language models are already loaded."""
    warmed = list(_model_cache.keys())
    return {"status": "healthy", "cache": warmed}
@app.post("/summarize")
async def summarize(req: SummarizeRequest):
    """Summarize req.text with a language-appropriate model.

    Returns {"summary": str, "lang": "ru"|"en"}.
    Raises HTTP 503 when the transformers pipeline is unavailable and
    HTTP 500 when inference itself fails.
    """
    # Bug fix: the original guessed the language by checking for the single
    # Cyrillic letter "а", never using the `langdetect.detect` imported at the
    # top of the file. Use the real detector, falling back to English if it
    # cannot decide (e.g. langdetect missing or text too ambiguous).
    try:
        detected = detect(req.text)
    except Exception:
        detected = "en"
    lang = "ru" if detected == "ru" else "en"
    try:
        model = get_model(lang)
    except RuntimeError as e:
        # Surfaced when transformers failed to import at startup.
        raise HTTPException(status_code=503, detail=str(e))
    try:
        # Cap input at 2000 chars (same limit as before) to bound CPU inference.
        result = model(req.text[:2000], max_length=180, min_length=50, do_sample=False)
    except Exception as e:
        logger.error("Summarization failed: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail="summarization failed")
    return {"summary": result[0]["summary_text"].strip(), "lang": lang}
# ======= MEMORY API =======
@app.post("/memorize")
async def memorize(req: MemoryRequest):
    """Append one {key, content} record to the newline-delimited JSON store.

    Returns {"status": "saved"}.
    """
    # Bug fix: the original wrote "\\n" — a literal backslash followed by 'n' —
    # so every record landed on the same physical line and retrieve()'s
    # per-line json.loads would choke after the first record. Write a real
    # newline, and pin utf-8 since records may contain non-ASCII text.
    with open("memory.json", "a", encoding="utf-8") as f:
        f.write(json.dumps(req.dict(), ensure_ascii=False) + "\n")
    return {"status": "saved"}
@app.post("/retrieve")
async def retrieve(req: MemoryRequest):
    """Return every stored record whose key contains req.key (case-insensitive).

    Returns {"found": [record, ...]}; an empty list when the store is absent.
    """
    if not os.path.exists("memory.json"):
        return {"found": []}
    needle = req.key.lower()
    found = []
    # Robustness: the store is append-only and best-effort, so tolerate blank
    # or corrupt lines instead of failing the whole request on json.loads.
    with open("memory.json", "r", encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue
            if needle in record.get("key", "").lower():
                found.append(record)
    return {"found": found}