Spaces:
Runtime error
Runtime error
File size: 3,261 Bytes
67f97e8 2ab9e4f 67f97e8 aadac55 2ab9e4f 81f57ff aadac55 ac8fce8 2ab9e4f ac8fce8 2ab9e4f ac8fce8 c4c9606 ac8fce8 aadac55 ac8fce8 aadac55 ac8fce8 aadac55 ac8fce8 2ab9e4f ac8fce8 9bb5368 c4c9606 aadac55 9bb5368 aadac55 9bb5368 2ab9e4f 3cab2bd 2ab9e4f aadac55 9bb5368 ac8fce8 9bb5368 ac8fce8 9bb5368 ac8fce8 aadac55 ac8fce8 9bb5368 ac8fce8 9bb5368 ac8fce8 9bb5368 ac8fce8 9bb5368 ac8fce8 aadac55 ac8fce8 9bb5368 2ab9e4f ac8fce8 2ab9e4f ac8fce8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import sys
import time
import json
import logging
import traceback
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from typing import Dict, Any
# ======= LOGGING =======
# Send all log records to stderr so stdout stays clean for supervisors.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stderr)],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("eroha-api")
# ======= OPTIONAL ML IMPORTS =======
# Guard each dependency independently: previously a single try block meant
# that a failure of EITHER transformers OR langdetect nulled `pipeline`
# and replaced `detect`, disabling both features at once.
try:
    from transformers import pipeline
except Exception as e:
    logging.getLogger("eroha-api").error(
        "[ImportError] transformers not available: %s", e, exc_info=True)
    pipeline = None  # sentinel checked by get_model()
try:
    from langdetect import detect
except Exception as e:
    logging.getLogger("eroha-api").error(
        "[ImportError] langdetect not available: %s", e, exc_info=True)
    def detect(text):
        # Fallback detector: assume English when langdetect is missing.
        return "en"
# ======= SETTINGS =======
# Keep the Hugging Face cache on the writable /tmp volume.
# NOTE(review): transformers is imported above this point and HF_HOME is
# normally read at import time, so this assignment may take effect too
# late for that library — confirm.
HF_HOME = "/tmp/huggingface"
os.makedirs(HF_HOME, exist_ok=True)
os.environ["HF_HOME"] = HF_HOME
# ======= MODELS =======
# Lazy per-language cache of summarization pipelines, keyed by language
# code ("ru"/"en"); populated on first use by get_model().
_model_cache: Dict[str, Any] = {}
def get_model(lang: str):
    """Return a summarization pipeline for *lang*, loading and caching it on first use.

    Raises RuntimeError when the transformers stack failed to import.
    Unknown language codes fall back to the English model.
    """
    if pipeline is None:
        raise RuntimeError("Transformers pipeline is not available")
    cached = _model_cache.get(lang)
    if cached is not None:
        return cached
    model_name = {
        "ru": "IlyaGusev/mbart_ru_sum_gazeta",
        "en": "facebook/bart-large-cnn",
    }.get(lang, "facebook/bart-large-cnn")
    # device=-1 forces CPU inference.
    summarizer = pipeline("summarization", model=model_name, device=-1)
    _model_cache[lang] = summarizer
    return summarizer
# ======= FastAPI =======
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm up both summarization models at startup; log on shutdown."""
    # NOTE(review): the start timestamp is captured but never reported.
    start = time.time()
    logger.info("[Startup] warming up models...")
    for language in ("en", "ru"):
        try:
            get_model(language)
        except Exception as exc:
            # Warm-up is best-effort; the endpoint will retry lazily.
            logger.error("Warmup failed: %s", exc)
    yield
    logger.info("[Shutdown] done")
app = FastAPI(title="Eroha Agent API", version="v3.5", lifespan=lifespan)
# Fully open CORS: any origin, method, and header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ======= REQUEST MODELS =======
class SummarizeRequest(BaseModel):
    """Body of POST /summarize: the raw text to summarize."""
    # Upper bound keeps one request from exhausting memory; the endpoint
    # additionally truncates to 2000 chars before inference.
    text: str = Field(..., min_length=3, max_length=1_000_000)
class MemoryRequest(BaseModel):
    """Body of /memorize and /retrieve: a key/content pair."""
    # Lookup key; /retrieve matches it as a case-insensitive substring.
    key: str
    # Free-form payload; ignored by /retrieve's filtering.
    content: str
# ======= ENDPOINTS =======
@app.get("/")
async def root():
    """Liveness probe reporting service status and API version."""
    return {"status": "ok", "version": "v3.5"}
@app.get("/ping")
async def ping():
    """Health check that also reports which language models are loaded."""
    # Iterating a dict yields its keys, in insertion order.
    return {"status": "healthy", "cache": list(_model_cache)}
@app.post("/summarize")
async def summarize(req: SummarizeRequest):
    """Summarize req.text with a Russian or English model.

    Returns {"summary": <text>, "lang": "ru"|"en"}.  Uses the imported
    langdetect `detect` (previously imported but unused) and falls back to
    the original Cyrillic-'а' heuristic when detection fails.  Responds
    503 when the transformers stack is unavailable instead of leaking a
    raw 500 from the RuntimeError.
    """
    try:
        lang = "ru" if detect(req.text) == "ru" else "en"
    except Exception:
        # Heuristic fallback: presence of Cyrillic 'а' suggests Russian.
        lang = "ru" if "а" in req.text.lower() else "en"
    try:
        model = get_model(lang)
    except RuntimeError as exc:
        raise HTTPException(status_code=503, detail=str(exc))
    # Truncate: the models' context windows are far below the 1M-char cap.
    result = model(req.text[:2000], max_length=180, min_length=50, do_sample=False)
    return {"summary": result[0]["summary_text"].strip(), "lang": lang}
# ======= MEMORY API =======
@app.post("/memorize")
async def memorize(req: MemoryRequest):
    """Append the key/content pair to memory.json as one JSON line.

    Bug fix: the record terminator was the two-character sequence
    backslash + 'n' rather than a real newline, so every record landed on
    a single physical line and /retrieve's line-based parser failed after
    the first entry.  Also opens the file with explicit UTF-8 so
    non-ASCII content round-trips regardless of locale.
    """
    with open("memory.json", "a", encoding="utf-8") as f:
        f.write(json.dumps(req.dict(), ensure_ascii=False) + "\n")
    return {"status": "saved"}
@app.post("/retrieve")
async def retrieve(req: MemoryRequest):
    """Return stored records whose "key" contains req.key (case-insensitive).

    Robustness fixes over the original: streams the file instead of
    materializing all lines, skips blank or corrupt JSON lines, and
    tolerates records without a "key" field — any of which previously
    crashed the request with a 500.  Reads with explicit UTF-8.
    """
    if not os.path.exists("memory.json"):
        return {"found": []}
    needle = req.key.lower()
    found = []
    with open("memory.json", "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate partially written / corrupt lines
            if needle in str(record.get("key", "")).lower():
                found.append(record)
    return {"found": found}
|