File size: 3,261 Bytes
67f97e8
2ab9e4f
67f97e8
aadac55
2ab9e4f
 
 
81f57ff
aadac55
ac8fce8
 
2ab9e4f
ac8fce8
2ab9e4f
 
 
 
 
ac8fce8
c4c9606
ac8fce8
aadac55
 
 
ac8fce8
aadac55
ac8fce8
 
aadac55
ac8fce8
2ab9e4f
 
 
 
ac8fce8
9bb5368
c4c9606
aadac55
9bb5368
aadac55
9bb5368
 
2ab9e4f
 
 
3cab2bd
2ab9e4f
aadac55
 
 
9bb5368
ac8fce8
9bb5368
ac8fce8
9bb5368
ac8fce8
 
 
 
aadac55
ac8fce8
9bb5368
ac8fce8
 
9bb5368
ac8fce8
 
 
9bb5368
ac8fce8
 
 
9bb5368
ac8fce8
aadac55
 
ac8fce8
9bb5368
2ab9e4f
 
ac8fce8
2ab9e4f
 
 
ac8fce8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import sys
import time
import json
import logging
import traceback
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager
from typing import Dict, Any

# ======= LOGGING =======
# All records go to stderr so stdout stays clean for the ASGI server.
_stderr_handler = logging.StreamHandler(sys.stderr)
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[_stderr_handler],
)
logger = logging.getLogger("eroha-api")

# ======= OPTIONAL IMPORTS =======
# transformers / langdetect are heavy optional dependencies. When either is
# missing we degrade gracefully: no summarization pipeline, and a language
# detector that always reports English.
try:
    from transformers import pipeline
    from langdetect import detect
except Exception as e:
    logger.error("[ImportError] transformers/langdetect not available: %s", e, exc_info=True)
    pipeline = None

    def detect(text):
        # Fallback detector used when langdetect cannot be imported.
        return "en"

# ======= SETTINGS =======
# Point the HuggingFace cache at /tmp (often the only writable location on
# serverless hosts) and make sure the directory exists before any download.
HF_HOME = "/tmp/huggingface"
os.makedirs(HF_HOME, exist_ok=True)
os.environ["HF_HOME"] = HF_HOME

# ======= MODELS =======
# Process-wide cache of loaded summarization pipelines, keyed by language code.
_model_cache: Dict[str, Any] = {}


def get_model(lang: str):
    """Return a cached summarization pipeline for *lang*, loading on first use.

    Unknown languages fall back to the English BART model. Raises
    RuntimeError when the transformers package could not be imported.
    """
    if pipeline is None:
        raise RuntimeError("Transformers pipeline is not available")
    try:
        return _model_cache[lang]
    except KeyError:
        pass  # not loaded yet — build it below
    checkpoints = {
        "ru": "IlyaGusev/mbart_ru_sum_gazeta",
        "en": "facebook/bart-large-cnn",
    }
    chosen = checkpoints.get(lang, "facebook/bart-large-cnn")
    summarizer = pipeline("summarization", model=chosen, device=-1)
    _model_cache[lang] = summarizer
    return summarizer

# ======= FastAPI =======
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: pre-load both models so the first request is fast.

    Warmup is best-effort — a failed download must not block startup.
    Fix: the original computed `start = time.time()` and never used it;
    the warmup duration is now actually logged.
    """
    start = time.time()
    logger.info("[Startup] warming up models...")
    for lang in ("en", "ru"):
        try:
            get_model(lang)
        except Exception as e:
            logger.error("Warmup failed: %s", e)
    logger.info("[Startup] warmup finished in %.1fs", time.time() - start)
    yield
    logger.info("[Shutdown] done")

app = FastAPI(title="Eroha Agent API", version="v3.5", lifespan=lifespan)

# NOTE(review): wildcard CORS ("*" origins/methods/headers) is fully
# permissive — confirm this API is meant to be callable from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ======= REQUEST MODELS =======
class SummarizeRequest(BaseModel):
    """Body for POST /summarize: raw text, 3 characters to 1 MB."""
    text: str = Field(..., min_length=3, max_length=1_000_000)

class MemoryRequest(BaseModel):
    """Body for the memory endpoints: a lookup key plus its content payload."""
    key: str
    content: str

# ======= ENDPOINTS =======
@app.get("/")
async def root():
    """Liveness probe: report service status and API version."""
    return dict(status="ok", version="v3.5")

@app.get("/ping")
async def ping():
    """Health check that also reports which language models are cached."""
    loaded = [code for code in _model_cache]
    return {"status": "healthy", "cache": loaded}

@app.post("/summarize")
async def summarize(req: SummarizeRequest):
    """Summarize the request text with a language-appropriate model.

    Fixes: uses langdetect's `detect` (imported at the top of the file with
    a careful fallback, but previously unused) instead of the crude
    "contains Cyrillic 'а'" heuristic, which misrouted any text containing
    a single Cyrillic character. Detection failures and unsupported
    languages fall back to English. Model errors surface as 503 rather
    than an opaque 500. Input is truncated to 2000 chars to bound latency.
    """
    try:
        lang = detect(req.text)
    except Exception:
        lang = "en"
    if lang != "ru":
        # Only "ru" and "en" checkpoints are configured in get_model.
        lang = "en"
    try:
        model = get_model(lang)
        result = model(req.text[:2000], max_length=180, min_length=50, do_sample=False)
    except RuntimeError as e:
        # transformers failed to import at startup — service unavailable.
        raise HTTPException(status_code=503, detail=str(e)) from e
    return {"summary": result[0]["summary_text"].strip(), "lang": lang}

# ======= MEMORY API =======
@app.post("/memorize")
async def memorize(req: MemoryRequest):
    """Append one memory record as a JSON line to memory.json.

    Bug fix: the original appended a literal backslash-n ("\\n") instead of
    a newline, so every record landed on one line and /retrieve's per-line
    json.loads could never parse the file. Also opens the file explicitly
    as UTF-8, since the payload is written with ensure_ascii=False.
    """
    with open("memory.json", "a", encoding="utf-8") as f:
        f.write(json.dumps(req.dict(), ensure_ascii=False) + "\n")
    return {"status": "saved"}

@app.post("/retrieve")
async def retrieve(req: MemoryRequest):
    """Return stored records whose key contains req.key (case-insensitive).

    Robustness fixes: reads the file as UTF-8, skips blank lines, skips
    corrupt (non-JSON) lines instead of failing the whole request, and
    tolerates records that lack a "key" field.
    """
    if not os.path.exists("memory.json"):
        return {"found": []}
    records = []
    with open("memory.json", "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines from manual edits
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # One bad line should not break retrieval of the rest.
                continue
    needle = req.key.lower()
    found = [r for r in records if needle in r.get("key", "").lower()]
    return {"found": found}