SarahXia0405 committed on
Commit
f671bf8
·
verified ·
1 Parent(s): 06efe20

Update api/server.py

Browse files
Files changed (1) hide show
  1. api/server.py +196 -62
api/server.py CHANGED
@@ -1,10 +1,12 @@
1
  # api/server.py
2
  import os
3
  import time
4
- from typing import Dict, List, Optional, Tuple
5
- from fastapi import FastAPI, UploadFile, File, Form
6
- from fastapi.responses import FileResponse
 
7
  from fastapi.staticfiles import StaticFiles
 
8
  from pydantic import BaseModel
9
 
10
  from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL
@@ -20,52 +22,100 @@ from api.clare_core import (
20
  summarize_conversation,
21
  )
22
 
23
- MODULE10_PATH = os.path.join(os.path.dirname(__file__), "module10_responsible_ai.pdf")
 
 
 
 
 
 
24
  MODULE10_DOC_TYPE = "Literature Review / Paper"
25
 
 
 
 
 
 
 
 
26
  app = FastAPI(title="Clare API")
27
 
28
- # ---- serve web build ----
29
- WEB_DIST = os.path.join(os.path.dirname(__file__), "..", "web", "dist")
30
- app.mount("/assets", StaticFiles(directory=os.path.join(WEB_DIST, "assets")), name="assets")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  @app.get("/")
33
  def index():
34
- return FileResponse(os.path.join(WEB_DIST, "index.html"))
 
 
 
 
 
 
35
 
36
- # ---- in-memory session store (MVP) ----
37
- # 生产环境建议 Redis / DB;但你第一阶段完全够用
 
 
38
  SESSIONS: Dict[str, Dict] = {}
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def _get_session(user_id: str) -> Dict:
41
  if user_id not in SESSIONS:
42
- # preload module10
43
- course_outline = DEFAULT_COURSE_TOPICS
44
- rag_chunks = []
45
- if os.path.exists(MODULE10_PATH):
46
- rag_chunks = build_rag_chunks_from_file(MODULE10_PATH, MODULE10_DOC_TYPE)
47
-
48
  SESSIONS[user_id] = {
49
  "user_id": user_id,
50
  "name": "",
51
  "history": [],
52
  "weaknesses": [],
53
  "cognitive_state": {"confusion": 0, "mastery": 0},
54
- "course_outline": course_outline,
55
- "rag_chunks": rag_chunks,
 
56
  "model_name": DEFAULT_MODEL,
57
  }
58
  return SESSIONS[user_id]
59
 
 
 
 
 
60
  class LoginReq(BaseModel):
61
  name: str
62
  user_id: str
63
 
64
- @app.post("/api/login")
65
- def login(req: LoginReq):
66
- sess = _get_session(req.user_id)
67
- sess["name"] = req.name
68
- return {"ok": True, "user": {"name": req.name, "user_id": req.user_id}}
69
 
70
  class ChatReq(BaseModel):
71
  user_id: str
@@ -74,37 +124,80 @@ class ChatReq(BaseModel):
74
  language_preference: str = "Auto"
75
  doc_type: str = "Syllabus"
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @app.post("/api/chat")
78
  def chat(req: ChatReq):
79
- sess = _get_session(req.user_id)
80
-
81
  msg = (req.message or "").strip()
 
 
 
 
 
 
82
  if not msg:
83
- return {"reply": "", "session_status_md": render_session_status(req.learning_mode, sess["weaknesses"], sess["cognitive_state"]), "refs": []}
 
 
 
 
 
 
 
84
 
85
  resolved_lang = detect_language(msg, req.language_preference)
86
 
87
  sess["weaknesses"] = update_weaknesses_from_message(msg, sess["weaknesses"])
88
  sess["cognitive_state"] = update_cognitive_state_from_message(msg, sess["cognitive_state"])
89
 
90
- # academic gating:沿用你 app.py 的 is_academic_query 逻辑(建议后续挪进 api utils)
91
  rag_context_text, rag_used_chunks = retrieve_relevant_chunks(msg, sess["rag_chunks"])
92
 
93
  start_ts = time.time()
94
- answer, new_history = chat_with_clare(
95
- message=msg,
96
- history=sess["history"],
97
- model_name=sess["model_name"],
98
- language_preference=resolved_lang,
99
- learning_mode=req.learning_mode,
100
- doc_type=req.doc_type,
101
- course_outline=sess["course_outline"],
102
- weaknesses=sess["weaknesses"],
103
- cognitive_state=sess["cognitive_state"],
104
- rag_context=rag_context_text,
105
- )
106
- latency_ms = (time.time() - start_ts) * 1000.0
 
 
 
107
 
 
108
  sess["history"] = new_history
109
 
110
  refs = [
@@ -114,46 +207,67 @@ def chat(req: ChatReq):
114
 
115
  return {
116
  "reply": answer,
117
- "session_status_md": render_session_status(req.learning_mode, sess["weaknesses"], sess["cognitive_state"]),
 
 
118
  "refs": refs,
119
  "latency_ms": latency_ms,
120
  }
121
 
 
122
  @app.post("/api/upload")
123
  async def upload(
124
  user_id: str = Form(...),
125
  doc_type: str = Form(...),
126
  file: UploadFile = File(...),
127
  ):
 
 
 
 
 
 
 
 
128
  sess = _get_session(user_id)
129
 
130
- # 保存到临时文件
131
  tmp_path = f"/tmp/{file.filename}"
132
  content = await file.read()
133
  with open(tmp_path, "wb") as f:
134
  f.write(content)
135
 
136
- # 更新 topics(仅 syllabus
137
  if doc_type == "Syllabus":
138
- # 复用你 syllabus_utils 的逻辑:它期待 gradio file_obj,有 .name
139
- class _F: pass
140
- fo = _F(); fo.name = tmp_path
141
- sess["course_outline"] = extract_course_topics_from_file(fo, doc_type)
142
 
143
- # 更新 rag chunks(所有文件都可)
144
- new_chunks = build_rag_chunks_from_file(tmp_path, doc_type)
145
- sess["rag_chunks"] = (sess["rag_chunks"] or []) + (new_chunks or [])
 
 
 
146
 
147
- status_md = f"✅ Loaded Module 10 base reading + uploaded {doc_type} file."
 
 
 
 
 
 
 
 
148
  return {"ok": True, "added_chunks": len(new_chunks), "status_md": status_md}
149
 
150
- class ExportReq(BaseModel):
151
- user_id: str
152
- learning_mode: str
153
 
154
  @app.post("/api/export")
155
  def api_export(req: ExportReq):
156
- sess = _get_session(req.user_id)
 
 
 
 
157
  md = export_conversation(
158
  sess["history"],
159
  sess["course_outline"],
@@ -163,14 +277,14 @@ def api_export(req: ExportReq):
163
  )
164
  return {"markdown": md}
165
 
166
- class SummaryReq(BaseModel):
167
- user_id: str
168
- learning_mode: str
169
- language_preference: str = "Auto"
170
 
171
  @app.post("/api/summary")
172
  def api_summary(req: SummaryReq):
173
- sess = _get_session(req.user_id)
 
 
 
 
174
  md = summarize_conversation(
175
  sess["history"],
176
  sess["course_outline"],
@@ -181,8 +295,28 @@ def api_summary(req: SummaryReq):
181
  )
182
  return {"markdown": md}
183
 
 
184
  @app.get("/api/memoryline")
185
  def memoryline(user_id: str):
186
- # v1 写死也可以;前端只渲染
187
- _ = _get_session(user_id)
188
  return {"next_review_label": "T+7", "progress_pct": 0.4}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # api/server.py
2
  import os
3
  import time
4
+ from typing import Dict
5
+
6
+ from fastapi import FastAPI, UploadFile, File, Form, Request
7
+ from fastapi.responses import FileResponse, JSONResponse
8
  from fastapi.staticfiles import StaticFiles
9
+ from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
11
 
12
  from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL
 
22
  summarize_conversation,
23
  )
24
 
25
# ----------------------------
# Filesystem layout / constants
# ----------------------------
# Directory that contains this module.
API_DIR = os.path.dirname(__file__)
# One level above API_DIR, normalised to an absolute path.
REPO_ROOT = os.path.abspath(os.path.join(API_DIR, os.pardir))

# Base reading expected next to this module, and the doc type it is indexed under.
MODULE10_DOC_TYPE = "Literature Review / Paper"
MODULE10_PATH = os.path.join(API_DIR, "module10_responsible_ai.pdf")

# Frontend build output: <repo>/web/dist/{index.html, assets/}.
WEB_DIST = os.path.join(REPO_ROOT, "web", "dist")
WEB_ASSETS = os.path.join(WEB_DIST, "assets")
WEB_INDEX = os.path.join(WEB_DIST, "index.html")
37
+
38
# ----------------------------
# App
# ----------------------------
app = FastAPI(title="Clare API")

# CORS: keeps a future split of frontend and backend domains working; for a
# same-origin deployment it is harmless.
# allow_credentials is False on purpose: a wildcard origin must not be combined
# with credentialed requests (see the FastAPI CORS docs), and the endpoints
# visible in this file identify the user through an explicit user_id field
# rather than cookies. If cookie/auth-header based sessions are introduced,
# list the trusted origins explicitly before turning credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
52
+
53
# ----------------------------
# Static hosting (Vite build)
# ----------------------------
# A Vite build normally yields web/dist/index.html plus web/dist/assets/*.
# "/assets" backs the <script src="/assets/..."> tags; "/static" exposes the
# rest of the dist root (favicon, etc.). A mount is skipped when its directory
# does not exist.
for _route, _directory, _mount_name in (
    ("/assets", WEB_ASSETS, "assets"),
    ("/static", WEB_DIST, "static"),
):
    if os.path.isdir(_directory):
        app.mount(_route, StaticFiles(directory=_directory), name=_mount_name)
64
+
65
 
66
  @app.get("/")
67
  def index():
68
+ if os.path.exists(WEB_INDEX):
69
+ return FileResponse(WEB_INDEX)
70
+ return JSONResponse(
71
+ {"detail": "web/dist not found. Build frontend first (web/dist/index.html)."},
72
+ status_code=500,
73
+ )
74
+
75
 
76
# ----------------------------
# In-memory session store (MVP)
# ----------------------------
# Maps user_id -> that user's session dict; held in a plain module-level dict.
# For production, move this to Redis or a database.
SESSIONS: Dict[str, Dict] = dict()
81
 
82
+
83
def _preload_module10_chunks():
    """Build the RAG chunks for the Module 10 base reading.

    Returns whatever ``build_rag_chunks_from_file`` produces for
    ``MODULE10_PATH``, or an empty list when the file does not exist, when the
    builder returns a falsy value, or when parsing raises.
    """
    if not os.path.exists(MODULE10_PATH):
        return []
    try:
        chunks = build_rag_chunks_from_file(MODULE10_PATH, MODULE10_DOC_TYPE)
    except Exception as e:
        # Best effort: this runs at import time, so a parse failure is logged
        # and swallowed instead of preventing the module from loading.
        print(f"[preload] module10 parse failed: {repr(e)}")
        return []
    return chunks or []


# Computed once when the module is imported.
MODULE10_CHUNKS_CACHE = _preload_module10_chunks()
94
+
95
+
96
  def _get_session(user_id: str) -> Dict:
97
  if user_id not in SESSIONS:
 
 
 
 
 
 
98
  SESSIONS[user_id] = {
99
  "user_id": user_id,
100
  "name": "",
101
  "history": [],
102
  "weaknesses": [],
103
  "cognitive_state": {"confusion": 0, "mastery": 0},
104
+ "course_outline": DEFAULT_COURSE_TOPICS,
105
+ # preload base reading
106
+ "rag_chunks": list(MODULE10_CHUNKS_CACHE),
107
  "model_name": DEFAULT_MODEL,
108
  }
109
  return SESSIONS[user_id]
110
 
111
+
112
# ----------------------------
# Schemas
# ----------------------------
# Request body accepted by POST /api/login.
class LoginReq(BaseModel):
    name: str  # stored under the session's "name" key by the login handler
    user_id: str  # key used to look up (or create) the session
118
 
 
 
 
 
 
119
 
120
  class ChatReq(BaseModel):
121
  user_id: str
 
124
  language_preference: str = "Auto"
125
  doc_type: str = "Syllabus"
126
 
127
+
128
# Request body accepted by POST /api/export.
class ExportReq(BaseModel):
    user_id: str  # key used to look up (or create) the session
    learning_mode: str
131
+
132
+
133
# Request body accepted by POST /api/summary.
class SummaryReq(BaseModel):
    user_id: str  # key used to look up (or create) the session
    learning_mode: str
    language_preference: str = "Auto"  # optional; defaults to "Auto"
137
+
138
+
139
+ # ----------------------------
140
+ # API Routes
141
+ # ----------------------------
142
# Record a user's display name on their session (creating the session if
# needed). Both fields are trimmed; a blank value for either yields a 400.
@app.post("/api/login")
def login(req: LoginReq):
    cleaned = {
        "name": (req.name or "").strip(),
        "user_id": (req.user_id or "").strip(),
    }
    if not all(cleaned.values()):
        return JSONResponse({"ok": False, "error": "Missing name/user_id"}, status_code=400)

    _get_session(cleaned["user_id"])["name"] = cleaned["name"]
    return {"ok": True, "user": cleaned}
152
+
153
+
154
  @app.post("/api/chat")
155
  def chat(req: ChatReq):
156
+ user_id = (req.user_id or "").strip()
 
157
  msg = (req.message or "").strip()
158
+
159
+ if not user_id:
160
+ return JSONResponse({"error": "Missing user_id"}, status_code=400)
161
+
162
+ sess = _get_session(user_id)
163
+
164
  if not msg:
165
+ return {
166
+ "reply": "",
167
+ "session_status_md": render_session_status(
168
+ req.learning_mode, sess["weaknesses"], sess["cognitive_state"]
169
+ ),
170
+ "refs": [],
171
+ "latency_ms": 0.0,
172
+ }
173
 
174
  resolved_lang = detect_language(msg, req.language_preference)
175
 
176
  sess["weaknesses"] = update_weaknesses_from_message(msg, sess["weaknesses"])
177
  sess["cognitive_state"] = update_cognitive_state_from_message(msg, sess["cognitive_state"])
178
 
179
+ # RAG
180
  rag_context_text, rag_used_chunks = retrieve_relevant_chunks(msg, sess["rag_chunks"])
181
 
182
  start_ts = time.time()
183
+ try:
184
+ answer, new_history = chat_with_clare(
185
+ message=msg,
186
+ history=sess["history"],
187
+ model_name=sess["model_name"],
188
+ language_preference=resolved_lang,
189
+ learning_mode=req.learning_mode,
190
+ doc_type=req.doc_type,
191
+ course_outline=sess["course_outline"],
192
+ weaknesses=sess["weaknesses"],
193
+ cognitive_state=sess["cognitive_state"],
194
+ rag_context=rag_context_text,
195
+ )
196
+ except Exception as e:
197
+ print(f"[chat] error: {repr(e)}")
198
+ return JSONResponse({"error": f"chat failed: {repr(e)}"}, status_code=500)
199
 
200
+ latency_ms = (time.time() - start_ts) * 1000.0
201
  sess["history"] = new_history
202
 
203
  refs = [
 
207
 
208
  return {
209
  "reply": answer,
210
+ "session_status_md": render_session_status(
211
+ req.learning_mode, sess["weaknesses"], sess["cognitive_state"]
212
+ ),
213
  "refs": refs,
214
  "latency_ms": latency_ms,
215
  }
216
 
217
+
218
  @app.post("/api/upload")
219
  async def upload(
220
  user_id: str = Form(...),
221
  doc_type: str = Form(...),
222
  file: UploadFile = File(...),
223
  ):
224
+ user_id = (user_id or "").strip()
225
+ doc_type = (doc_type or "").strip()
226
+
227
+ if not user_id:
228
+ return JSONResponse({"ok": False, "error": "Missing user_id"}, status_code=400)
229
+ if not file or not file.filename:
230
+ return JSONResponse({"ok": False, "error": "Missing file"}, status_code=400)
231
+
232
  sess = _get_session(user_id)
233
 
234
+ # Save to /tmp
235
  tmp_path = f"/tmp/{file.filename}"
236
  content = await file.read()
237
  with open(tmp_path, "wb") as f:
238
  f.write(content)
239
 
240
+ # Update topics only for syllabus
241
  if doc_type == "Syllabus":
242
+ class _F:
243
+ pass
 
 
244
 
245
+ fo = _F()
246
+ fo.name = tmp_path
247
+ try:
248
+ sess["course_outline"] = extract_course_topics_from_file(fo, doc_type)
249
+ except Exception as e:
250
+ print(f"[upload] syllabus parse error: {repr(e)}")
251
 
252
+ # Update rag chunks for any doc
253
+ try:
254
+ new_chunks = build_rag_chunks_from_file(tmp_path, doc_type) or []
255
+ sess["rag_chunks"] = (sess["rag_chunks"] or []) + new_chunks
256
+ except Exception as e:
257
+ print(f"[upload] rag build error: {repr(e)}")
258
+ new_chunks = []
259
+
260
+ status_md = f"✅ Loaded base reading + uploaded {doc_type} file."
261
  return {"ok": True, "added_chunks": len(new_chunks), "status_md": status_md}
262
 
 
 
 
263
 
264
  @app.post("/api/export")
265
  def api_export(req: ExportReq):
266
+ user_id = (req.user_id or "").strip()
267
+ if not user_id:
268
+ return JSONResponse({"error": "Missing user_id"}, status_code=400)
269
+
270
+ sess = _get_session(user_id)
271
  md = export_conversation(
272
  sess["history"],
273
  sess["course_outline"],
 
277
  )
278
  return {"markdown": md}
279
 
 
 
 
 
280
 
281
  @app.post("/api/summary")
282
  def api_summary(req: SummaryReq):
283
+ user_id = (req.user_id or "").strip()
284
+ if not user_id:
285
+ return JSONResponse({"error": "Missing user_id"}, status_code=400)
286
+
287
+ sess = _get_session(user_id)
288
  md = summarize_conversation(
289
  sess["history"],
290
  sess["course_outline"],
 
295
  )
296
  return {"markdown": md}
297
 
298
+
299
  @app.get("/api/memoryline")
300
  def memoryline(user_id: str):
301
+ _ = _get_session((user_id or "").strip())
302
+ # v1: 写死也没问题;前端只渲染
303
  return {"next_review_label": "T+7", "progress_pct": 0.4}
304
+
305
+
306
# ----------------------------
# SPA fallback (important!)
# ----------------------------
# Catch-all GET route: when the user refreshes /some/route, the frontend
# router needs index.html. Paths under api/, assets/ or static/ are answered
# with a JSON 404 instead of being hijacked; a missing build yields a JSON 500.
@app.get("/{full_path:path}")
def spa_fallback(full_path: str, request: Request):
    reserved_prefixes = ("api/", "assets/", "static/")
    if full_path.startswith(reserved_prefixes):
        return JSONResponse({"detail": "Not Found"}, status_code=404)

    if not os.path.exists(WEB_INDEX):
        return JSONResponse(
            {"detail": "web/dist not found. Build frontend first."},
            status_code=500,
        )
    return FileResponse(WEB_INDEX)