Spaces:
Sleeping
Sleeping
File size: 3,741 Bytes
320c4f1 03d7a3b 361d672 03d7a3b 361d672 03d7a3b 5fc69e4 361d672 5fc69e4 03d7a3b a274b81 5fc69e4 361d672 5fc69e4 320c4f1 8cebfe3 57578cb 5fc69e4 57578cb 5fc69e4 320c4f1 5fc69e4 8cebfe3 5fc69e4 8cebfe3 5fc69e4 320c4f1 8cebfe3 5fc69e4 320c4f1 361d672 320c4f1 8cebfe3 5fc69e4 8cebfe3 5fc69e4 320c4f1 361d672 8cebfe3 361d672 320c4f1 5fc69e4 320c4f1 18ac473 5fc69e4 18ac473 320c4f1 18ac473 320c4f1 5fc69e4 18ac473 5fc69e4 9191afb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
import markdown
from config import (BASE_DIR, EMBEDDER_MODEL_DIR, EMBEDDER_MODEL_NAME,
FALLBACK_MODEL_DIR, FALLBACK_MODEL_NAME, HF_TOKEN)
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from manager.dialogue_manager import handle_dialogue
from models.model_loader import load_embedder, load_fallback_model
from rag.rag_manager import (add_docs, chroma_initialized,
load_game_docs_from_disk, set_embedder)
from schemas import AskReq, AskRes
templates = Jinja2Templates(directory="templates")
model_ready = False
async def load_models(app: FastAPI):
global model_ready
print("π starting model loading...")
fb_tokenizer, fb_model = load_fallback_model(FALLBACK_MODEL_NAME, FALLBACK_MODEL_DIR, token=HF_TOKEN)
app.state.fallback_tokenizer = fb_tokenizer
app.state.fallback_model = fb_model
embedder = load_embedder(EMBEDDER_MODEL_NAME, EMBEDDER_MODEL_DIR, token=HF_TOKEN)
app.state.embedder = embedder
set_embedder(embedder)
docs_path = BASE_DIR / "rag" / "docs"
if not chroma_initialized():
docs = load_game_docs_from_disk(str(docs_path))
add_docs(docs)
print(f"β
finished inserting {len(docs)} documents into RAG DB")
else:
print("π already initialized RAG DB")
model_ready = True
print("β
model loading complete, server is ready to accept requests")
@asynccontextmanager
async def lifespan(app: FastAPI):
asyncio.create_task(load_models(app))
yield
print("π shutting down...")
app = FastAPI(title="neuro-engine", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["https://fpsgame-rrbc.onrender.com"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", include_in_schema=False)
async def root(request: Request):
md_path = Path(__file__).parent / "README.md"
md_content = md_path.read_text(encoding="utf-8")
start_tag = "<!-- app-tab:start -->"
end_tag = "<!-- app-tab:end -->"
if start_tag in md_content and end_tag in md_content:
short_md = md_content.split(start_tag)[1].split(end_tag)[0].strip()
else:
short_md = md_content # fallback: all content if tags not found
html_from_md = markdown.markdown(short_md)
return templates.TemplateResponse("index.html", {"request": request, "readme_content": html_from_md})
@app.get("/status")
async def status():
return {"ready": model_ready}
@app.post("/wake")
async def wake(request: Request):
session_id = (await request.json()).get("session_id", "unknown")
print(f"π‘ Wake signal received for session: {session_id}")
if not model_ready:
asyncio.create_task(load_models(app))
return {"status": "awake", "model_ready": model_ready}
@app.post("/ask", response_model=AskRes)
async def ask(request: Request, req: AskReq):
if not model_ready:
raise HTTPException(status_code=503, detail="Model not ready")
if not req.context:
raise HTTPException(status_code=400, detail="missing context")
if not (req.session_id and req.npc_id and req.user_input):
raise HTTPException(status_code=400, detail="missing fields")
context = req.context
npc_config_dict = context.npc_config.model_dump() if context.npc_config else None
return await handle_dialogue(
request=request,
session_id=req.session_id,
npc_id=req.npc_id,
user_input=req.user_input,
context=context.model_dump(),
npc_config=npc_config_dict
)
|