SarahXia0405 committed on
Commit
42a6c9d
·
verified ·
1 Parent(s): 1cf12a9

Create api/server.py

Browse files
Files changed (1) hide show
  1. api/api/server.py +188 -0
api/api/server.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api/server.py
2
+ import os
3
+ import time
4
+ from typing import Dict, List, Optional, Tuple
5
+ from fastapi import FastAPI, UploadFile, File, Form
6
+ from fastapi.responses import FileResponse
7
+ from fastapi.staticfiles import StaticFiles
8
+ from pydantic import BaseModel
9
+
10
+ from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL
11
+ from api.syllabus_utils import extract_course_topics_from_file
12
+ from api.rag_engine import build_rag_chunks_from_file, retrieve_relevant_chunks
13
+ from api.clare_core import (
14
+ detect_language,
15
+ chat_with_clare,
16
+ update_weaknesses_from_message,
17
+ update_cognitive_state_from_message,
18
+ render_session_status,
19
+ export_conversation,
20
+ summarize_conversation,
21
+ )
22
+
23
# Directory containing this module; bundled resources are resolved against it.
_API_DIR = os.path.dirname(__file__)

# Bundled base reading and the document type passed along when it is chunked.
MODULE10_PATH = os.path.join(_API_DIR, "module10_responsible_ai.pdf")
MODULE10_DOC_TYPE = "Literature Review / Paper"

app = FastAPI(title="Clare API")

# ---- serve web build ----
# The built front end lives in ../web/dist relative to this package; its
# "assets" sub-directory is exposed under /assets.
WEB_DIST = os.path.join(_API_DIR, "..", "web", "dist")
_ASSETS_DIR = os.path.join(WEB_DIST, "assets")
app.mount("/assets", StaticFiles(directory=_ASSETS_DIR), name="assets")
31
+
32
@app.get("/")
def index():
    """Serve the web build's ``index.html`` at the site root."""
    index_html = os.path.join(WEB_DIST, "index.html")
    return FileResponse(index_html)
35
+
36
# ---- in-memory session store (MVP) ----
# Redis or a database is recommended for production; a process-local dict is
# sufficient for the first phase.
SESSIONS: Dict[str, Dict] = {}

# Parsed chunks of the bundled Module 10 reading. Filled on the first
# successful load so the PDF is not re-parsed for every new session.
_MODULE10_CHUNKS: Optional[List] = None


def _load_module10_chunks() -> List:
    """Return a new list holding the preloaded Module 10 RAG chunks.

    The file at ``MODULE10_PATH`` is parsed at most once per process. A
    missing file yields an empty list and is *not* cached, so the reading is
    picked up if it appears later.
    """
    global _MODULE10_CHUNKS
    if _MODULE10_CHUNKS is None:
        if not os.path.exists(MODULE10_PATH):
            return []
        _MODULE10_CHUNKS = (
            build_rag_chunks_from_file(MODULE10_PATH, MODULE10_DOC_TYPE) or []
        )
    # Each session gets its own list; the chunk objects themselves are shared.
    # NOTE(review): assumes downstream code treats chunks as read-only — confirm.
    return list(_MODULE10_CHUNKS)


def _get_session(user_id: str) -> Dict:
    """Return the session dict for ``user_id``, creating it on first use.

    A new session starts with the default course outline, the default model,
    empty history/weaknesses, zeroed cognitive state, and the preloaded
    Module 10 chunks (empty when the bundled file is absent).

    Args:
        user_id: Key of the session in ``SESSIONS``.

    Returns:
        The mutable session dict stored in ``SESSIONS``; callers update it
        in place.
    """
    session = SESSIONS.get(user_id)
    if session is None:
        session = {
            "user_id": user_id,
            "name": "",
            "history": [],
            "weaknesses": [],
            "cognitive_state": {"confusion": 0, "mastery": 0},
            "course_outline": DEFAULT_COURSE_TOPICS,
            "rag_chunks": _load_module10_chunks(),
            "model_name": DEFAULT_MODEL,
        }
        SESSIONS[user_id] = session
    return session
59
+
60
class LoginReq(BaseModel):
    """Request body for ``POST /api/login``."""

    # Display name; the login endpoint stores it on the session.
    name: str
    # Key of the in-memory session to create or reuse.
    user_id: str
63
+
64
@app.post("/api/login")
def login(req: LoginReq):
    """Create or reuse the caller's session and record their display name.

    Returns:
        ``{"ok": True, "user": {...}}`` echoing the submitted name and id.
    """
    session = _get_session(req.user_id)
    session["name"] = req.name
    user = {"name": req.name, "user_id": req.user_id}
    return {"ok": True, "user": user}
69
+
70
class ChatReq(BaseModel):
    """Request body for ``POST /api/chat``."""

    # Key of the in-memory session the message belongs to.
    user_id: str
    # Raw user message; the chat endpoint strips surrounding whitespace.
    message: str
    # Forwarded unchanged to the tutor call and the status renderer.
    learning_mode: str
    # Passed to language detection together with the message.
    language_preference: str = "Auto"
    # Forwarded unchanged to the tutor call.
    doc_type: str = "Syllabus"
76
+
77
@app.post("/api/chat")
def chat(req: ChatReq):
    """Answer one user message and update the caller's session state.

    The message is stripped; a blank message short-circuits with an empty
    reply. Otherwise the session's weaknesses and cognitive state are updated
    from the message, relevant RAG chunks are retrieved, the tutor model is
    called, and the returned history replaces the session history.

    Returns:
        A dict with ``reply``, ``session_status_md``, ``refs`` (source file
        and section of each chunk used) and ``latency_ms`` (duration of the
        tutor call only; ``0.0`` when no call was made).
    """
    sess = _get_session(req.user_id)

    def _status_md():
        # Rendered from the session's *current* state, so call it after any
        # state updates that should be reflected.
        return render_session_status(
            req.learning_mode, sess["weaknesses"], sess["cognitive_state"]
        )

    msg = (req.message or "").strip()
    if not msg:
        # Same keys as the normal response so clients can rely on one shape.
        return {
            "reply": "",
            "session_status_md": _status_md(),
            "refs": [],
            "latency_ms": 0.0,
        }

    resolved_lang = detect_language(msg, req.language_preference)

    sess["weaknesses"] = update_weaknesses_from_message(msg, sess["weaknesses"])
    sess["cognitive_state"] = update_cognitive_state_from_message(
        msg, sess["cognitive_state"]
    )

    # TODO: academic gating is not applied here; retrieval runs for every
    # message. The original note suggests reusing app.py's is_academic_query
    # logic and moving it into the api utilities.
    rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
        msg, sess["rag_chunks"]
    )

    # perf_counter is monotonic, so the measured latency cannot go negative
    # or jump when the system clock is adjusted.
    started = time.perf_counter()
    answer, new_history = chat_with_clare(
        message=msg,
        history=sess["history"],
        model_name=sess["model_name"],
        language_preference=resolved_lang,
        learning_mode=req.learning_mode,
        doc_type=req.doc_type,
        course_outline=sess["course_outline"],
        weaknesses=sess["weaknesses"],
        cognitive_state=sess["cognitive_state"],
        rag_context=rag_context_text,
    )
    latency_ms = (time.perf_counter() - started) * 1000.0

    sess["history"] = new_history

    refs = [
        {"source_file": c.get("source_file"), "section": c.get("section")}
        for c in (rag_used_chunks or [])
    ]

    return {
        "reply": answer,
        "session_status_md": _status_md(),
        "refs": refs,
        "latency_ms": latency_ms,
    }
121
+
122
@app.post("/api/upload")
async def upload(
    user_id: str = Form(...),
    doc_type: str = Form(...),
    file: UploadFile = File(...),
):
    """Ingest an uploaded document into the caller's session.

    The upload is written to a private temporary directory, optionally used
    to refresh the course outline (only when ``doc_type == "Syllabus"``), and
    always chunked and appended to the session's RAG chunks. The temporary
    directory is removed before returning.

    Returns:
        ``{"ok": True, "added_chunks": <int>, "status_md": <str>}``.
    """
    import tempfile  # local import: only this endpoint needs it

    sess = _get_session(user_id)

    # file.filename is client-controlled. Keep only its final path component
    # so names like "../../x" or "C:\\dir\\x" cannot escape the temp dir.
    raw_name = (file.filename or "").replace("\\", "/").replace("\x00", "")
    safe_name = os.path.basename(raw_name)
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    content = await file.read()

    # A fresh directory per upload avoids collisions between concurrent
    # uploads of the same filename, while keeping the original basename
    # (and extension) for the downstream parsers.
    # NOTE(review): assumes the parsers read the file eagerly and keep no
    # reference to the path after returning — confirm.
    with tempfile.TemporaryDirectory(prefix="clare_upload_") as tmp_dir:
        tmp_path = os.path.join(tmp_dir, safe_name)
        with open(tmp_path, "wb") as out:
            out.write(content)

        # Refresh topics (syllabus only).
        if doc_type == "Syllabus":
            # extract_course_topics_from_file expects a Gradio-style file
            # object exposing ``.name``; wrap the path in a minimal stand-in.
            class _F:
                pass

            fo = _F()
            fo.name = tmp_path
            sess["course_outline"] = extract_course_topics_from_file(fo, doc_type)

        # Refresh RAG chunks (any document type). Normalise a None result so
        # both the merge and the count below are safe.
        new_chunks = build_rag_chunks_from_file(tmp_path, doc_type) or []

    sess["rag_chunks"] = (sess["rag_chunks"] or []) + new_chunks

    status_md = f"✅ Loaded Module 10 base reading + uploaded {doc_type} file."
    return {"ok": True, "added_chunks": len(new_chunks), "status_md": status_md}
149
+
150
class ExportReq(BaseModel):
    """Request body for ``POST /api/export``."""

    # Key of the in-memory session whose conversation is exported.
    user_id: str
    # Forwarded unchanged to the conversation exporter.
    learning_mode: str
153
+
154
@app.post("/api/export")
def api_export(req: ExportReq):
    """Export the caller's conversation via ``export_conversation``.

    Returns:
        ``{"markdown": <export_conversation result>}``.
    """
    session = _get_session(req.user_id)
    history = session["history"]
    outline = session["course_outline"]
    weaknesses = session["weaknesses"]
    cognitive_state = session["cognitive_state"]
    markdown = export_conversation(
        history, outline, req.learning_mode, weaknesses, cognitive_state
    )
    return {"markdown": markdown}
165
+
166
class SummaryReq(BaseModel):
    """Request body for ``POST /api/summary``."""

    # Key of the in-memory session whose conversation is summarised.
    user_id: str
    # Required by the schema; the summary endpoint does not read it.
    learning_mode: str
    # Forwarded unchanged to the summariser.
    language_preference: str = "Auto"
170
+
171
@app.post("/api/summary")
def api_summary(req: SummaryReq):
    """Summarise the caller's conversation via ``summarize_conversation``.

    The request's ``language_preference`` is passed through as-is (no
    language detection is run here).

    Returns:
        ``{"markdown": <summarize_conversation result>}``.
    """
    session = _get_session(req.user_id)
    summary_args = (
        session["history"],
        session["course_outline"],
        session["weaknesses"],
        session["cognitive_state"],
        session["model_name"],
        req.language_preference,
    )
    markdown = summarize_conversation(*summary_args)
    return {"markdown": markdown}
183
+
184
@app.get("/api/memoryline")
def memoryline(user_id: str):
    """Return the review-timeline data rendered by the front end.

    The payload is a fixed placeholder; the call also ensures a session
    exists for ``user_id``.
    """
    # Hard-coded values are fine for v1; the front end only renders them.
    _get_session(user_id)  # return value intentionally unused
    placeholder = {"next_review_label": "T+7", "progress_pct": 0.4}
    return placeholder