# api/server.py
import os
import tempfile
import time
from typing import Dict, List, Optional, Tuple

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL
from api.syllabus_utils import extract_course_topics_from_file
from api.rag_engine import build_rag_chunks_from_file, retrieve_relevant_chunks
from api.clare_core import (
    detect_language,
    chat_with_clare,
    update_weaknesses_from_message,
    update_cognitive_state_from_message,
    render_session_status,
    export_conversation,
    summarize_conversation,
)
# Bundled base reading: a Module 10 PDF shipped alongside this module.
MODULE10_PATH = os.path.join(os.path.dirname(__file__), "module10_responsible_ai.pdf")
# Doc-type label forwarded to the RAG chunker for the bundled reading.
MODULE10_DOC_TYPE = "Literature Review / Paper"
app = FastAPI(title="Clare API")
# ---- serve web build ----
# Location of the compiled frontend bundle (web/dist), relative to this file.
WEB_DIST = os.path.join(os.path.dirname(__file__), "..", "web", "dist")
# NOTE(review): StaticFiles raises at import time if web/dist/assets is
# missing — confirm the frontend is built before the API starts.
app.mount("/assets", StaticFiles(directory=os.path.join(WEB_DIST, "assets")), name="assets")
@app.get("/")
def index():
return FileResponse(os.path.join(WEB_DIST, "index.html"))
# ---- in-memory session store (MVP) ----
# Production should use Redis / a DB; for the first phase this is plenty.
# Maps user_id -> mutable session dict (name, history, weaknesses,
# cognitive_state, course_outline, rag_chunks, model_name).
SESSIONS: Dict[str, Dict] = {}
def _load_base_rag_chunks() -> List:
    """Parse the bundled Module 10 reading into RAG chunks, caching the result.

    The original code re-parsed the PDF for every new user; the content is a
    static bundled file, so one parse is enough for the process lifetime.
    """
    cached = getattr(_load_base_rag_chunks, "_cache", None)
    if cached is None:
        cached = []
        if os.path.exists(MODULE10_PATH):
            # Normalize a possible None return to an empty list.
            cached = build_rag_chunks_from_file(MODULE10_PATH, MODULE10_DOC_TYPE) or []
        _load_base_rag_chunks._cache = cached
    return cached


def _get_session(user_id: str) -> Dict:
    """Return the session dict for *user_id*, creating it on first access.

    New sessions are seeded with the default course outline and the RAG
    chunks of the bundled Module 10 reading (parsed once, see above).
    """
    if user_id not in SESSIONS:
        SESSIONS[user_id] = {
            "user_id": user_id,
            "name": "",
            "history": [],
            "weaknesses": [],
            "cognitive_state": {"confusion": 0, "mastery": 0},
            "course_outline": DEFAULT_COURSE_TOPICS,
            # Copy the list so per-session list operations never touch the
            # shared cache (chunk objects themselves are shared, read-only).
            "rag_chunks": list(_load_base_rag_chunks()),
            "model_name": DEFAULT_MODEL,
        }
    return SESSIONS[user_id]
class LoginReq(BaseModel):
    """Request body for POST /api/login."""
    name: str  # display name stored on the session
    user_id: str  # stable identifier keying the in-memory session store
@app.post("/api/login")
def login(req: LoginReq):
sess = _get_session(req.user_id)
sess["name"] = req.name
return {"ok": True, "user": {"name": req.name, "user_id": req.user_id}}
class ChatReq(BaseModel):
    """Request body for POST /api/chat."""
    user_id: str  # session key
    message: str  # raw user message (stripped server-side)
    learning_mode: str  # forwarded to chat_with_clare / render_session_status
    language_preference: str = "Auto"  # "Auto" defers to detect_language
    doc_type: str = "Syllabus"  # document-type hint forwarded to the model
@app.post("/api/chat")
def chat(req: ChatReq):
sess = _get_session(req.user_id)
msg = (req.message or "").strip()
if not msg:
return {"reply": "", "session_status_md": render_session_status(req.learning_mode, sess["weaknesses"], sess["cognitive_state"]), "refs": []}
resolved_lang = detect_language(msg, req.language_preference)
sess["weaknesses"] = update_weaknesses_from_message(msg, sess["weaknesses"])
sess["cognitive_state"] = update_cognitive_state_from_message(msg, sess["cognitive_state"])
# academic gating:沿用你 app.py 的 is_academic_query 逻辑(建议后续挪进 api utils)
rag_context_text, rag_used_chunks = retrieve_relevant_chunks(msg, sess["rag_chunks"])
start_ts = time.time()
answer, new_history = chat_with_clare(
message=msg,
history=sess["history"],
model_name=sess["model_name"],
language_preference=resolved_lang,
learning_mode=req.learning_mode,
doc_type=req.doc_type,
course_outline=sess["course_outline"],
weaknesses=sess["weaknesses"],
cognitive_state=sess["cognitive_state"],
rag_context=rag_context_text,
)
latency_ms = (time.time() - start_ts) * 1000.0
sess["history"] = new_history
refs = [
{"source_file": c.get("source_file"), "section": c.get("section")}
for c in (rag_used_chunks or [])
]
return {
"reply": answer,
"session_status_md": render_session_status(req.learning_mode, sess["weaknesses"], sess["cognitive_state"]),
"refs": refs,
"latency_ms": latency_ms,
}
@app.post("/api/upload")
async def upload(
user_id: str = Form(...),
doc_type: str = Form(...),
file: UploadFile = File(...),
):
sess = _get_session(user_id)
# 保存到临时文件
tmp_path = f"/tmp/{file.filename}"
content = await file.read()
with open(tmp_path, "wb") as f:
f.write(content)
# 更新 topics(仅 syllabus)
if doc_type == "Syllabus":
# 复用你 syllabus_utils 的逻辑:它期待 gradio file_obj,有 .name
class _F: pass
fo = _F(); fo.name = tmp_path
sess["course_outline"] = extract_course_topics_from_file(fo, doc_type)
# 更新 rag chunks(所有文件都可)
new_chunks = build_rag_chunks_from_file(tmp_path, doc_type)
sess["rag_chunks"] = (sess["rag_chunks"] or []) + (new_chunks or [])
status_md = f"✅ Loaded Module 10 base reading + uploaded {doc_type} file."
return {"ok": True, "added_chunks": len(new_chunks), "status_md": status_md}
class ExportReq(BaseModel):
    """Request body for POST /api/export."""
    user_id: str  # session key
    learning_mode: str  # forwarded to export_conversation
@app.post("/api/export")
def api_export(req: ExportReq):
sess = _get_session(req.user_id)
md = export_conversation(
sess["history"],
sess["course_outline"],
req.learning_mode,
sess["weaknesses"],
sess["cognitive_state"],
)
return {"markdown": md}
class SummaryReq(BaseModel):
    """Request body for POST /api/summary."""
    user_id: str  # session key
    learning_mode: str  # NOTE(review): accepted but unused by api_summary
    language_preference: str = "Auto"  # forwarded to summarize_conversation
@app.post("/api/summary")
def api_summary(req: SummaryReq):
sess = _get_session(req.user_id)
md = summarize_conversation(
sess["history"],
sess["course_outline"],
sess["weaknesses"],
sess["cognitive_state"],
sess["model_name"],
req.language_preference,
)
return {"markdown": md}
@app.get("/api/memoryline")
def memoryline(user_id: str):
# v1 写死也可以;前端只渲染
_ = _get_session(user_id)
return {"next_review_label": "T+7", "progress_pct": 0.4}