File size: 5,980 Bytes
42a6c9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# api/server.py
import os
import shutil
import tempfile
import time
import types
from typing import Dict, List, Optional, Tuple

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from api.config import DEFAULT_COURSE_TOPICS, DEFAULT_MODEL
from api.syllabus_utils import extract_course_topics_from_file
from api.rag_engine import build_rag_chunks_from_file, retrieve_relevant_chunks
from api.clare_core import (
    detect_language,
    chat_with_clare,
    update_weaknesses_from_message,
    update_cognitive_state_from_message,
    render_session_status,
    export_conversation,
    summarize_conversation,
)

# Bundled base reading ("Module 10") that _get_session preloads into every new
# session's RAG chunks when the file exists; the doc type below is what those
# chunks are built with.
MODULE10_PATH = os.path.join(os.path.dirname(__file__), "module10_responsible_ai.pdf")
MODULE10_DOC_TYPE = "Literature Review / Paper"

app = FastAPI(title="Clare API")

# ---- serve web build ----
# Front-end build output: the `web/dist` directory one level above this file's
# directory. Its `assets/` subfolder is served under /assets, and index.html is
# served by the "/" route.
# NOTE(review): StaticFiles checks that the directory exists by default, so
# importing this module presumably fails if web/dist/assets has not been
# built yet — confirm.
WEB_DIST = os.path.join(os.path.dirname(__file__), "..", "web", "dist")
app.mount("/assets", StaticFiles(directory=os.path.join(WEB_DIST, "assets")), name="assets")

@app.get("/")
def index():
    """Serve the single-page app's entry HTML from the web build directory."""
    index_html = os.path.join(WEB_DIST, "index.html")
    return FileResponse(index_html)

# ---- in-memory session store (MVP) ----
# For production, Redis or a database is recommended; for this first phase an
# in-process dict is fully sufficient.
# Maps user_id -> session dict (the keys are created in _get_session). Because
# it is a plain module-level dict, state lives only in this process: it is lost
# on restart and is not shared between multiple worker processes.
SESSIONS: Dict[str, Dict] = {}

# Cached RAG chunks for the bundled Module 10 reading. They are built on first
# use, so the PDF is parsed and chunked once per process rather than once per
# new user.
_MODULE10_CHUNKS: Optional[List] = None


def _module10_chunks() -> List:
    """Return a fresh list of the Module 10 base-reading chunks.

    The chunks are built on the first call where the PDF exists and are cached
    after that. While the file is absent, an empty list is returned and nothing
    is cached, so the file is picked up if it appears later. Callers get a
    shallow copy, so changes to one session's list never reach another session.
    NOTE(review): the chunk objects themselves are shared between sessions;
    this assumes rag_engine does not mutate chunks in place — confirm.
    """
    global _MODULE10_CHUNKS
    if _MODULE10_CHUNKS is None and os.path.exists(MODULE10_PATH):
        _MODULE10_CHUNKS = build_rag_chunks_from_file(MODULE10_PATH, MODULE10_DOC_TYPE) or []
    return list(_MODULE10_CHUNKS or [])


def _get_session(user_id: str) -> Dict:
    """Return the session for *user_id*, creating and registering it on first use.

    A new session starts with:
    - an empty name, history and weakness list;
    - a zeroed cognitive state;
    - the default course topics and the default model;
    - the Module 10 base reading, pre-chunked for retrieval (an empty list if
      the PDF is absent).
    """
    sess = SESSIONS.get(user_id)
    if sess is not None:
        return sess

    sess = {
        "user_id": user_id,
        "name": "",
        "history": [],
        "weaknesses": [],
        "cognitive_state": {"confusion": 0, "mastery": 0},
        "course_outline": DEFAULT_COURSE_TOPICS,
        "rag_chunks": _module10_chunks(),
        "model_name": DEFAULT_MODEL,
    }
    SESSIONS[user_id] = sess
    return sess

class LoginReq(BaseModel):
    """Request body for POST /api/login."""

    name: str  # display name; stored on the session by login()
    user_id: str  # key into SESSIONS; the session is created on first use

@app.post("/api/login")
def login(req: LoginReq):
    """Record the caller's display name on their session, creating it if needed.

    Echoes the user identity back so the front end can confirm the login.
    """
    user = {"name": req.name, "user_id": req.user_id}
    _get_session(req.user_id)["name"] = req.name
    return {"ok": True, "user": user}

class ChatReq(BaseModel):
    """Request body for POST /api/chat."""

    user_id: str  # key into SESSIONS
    message: str  # raw user text; chat() strips it and ignores blank input
    learning_mode: str  # forwarded to chat_with_clare and render_session_status
    language_preference: str = "Auto"  # resolved via detect_language() before use
    doc_type: str = "Syllabus"  # forwarded to chat_with_clare

@app.post("/api/chat")
def chat(req: ChatReq):
    """Answer one chat turn for the caller's session.

    Steps:
    1. Update the session's weaknesses and cognitive state from the message.
    2. Retrieve supporting RAG chunks.
    3. Ask the model for a reply and store the new history.

    Returns a dict with:
    - ``reply``: the model's answer;
    - ``session_status_md``: a Markdown session-status panel;
    - ``refs``: the source/section of each chunk used;
    - ``latency_ms``: the model-call latency in milliseconds.

    A blank message short-circuits: no model call and no state change, but the
    response has the same shape.
    """
    sess = _get_session(req.user_id)

    msg = (req.message or "").strip()
    if not msg:
        return {
            "reply": "",
            "session_status_md": render_session_status(
                req.learning_mode, sess["weaknesses"], sess["cognitive_state"]
            ),
            "refs": [],
            # Keep the same keys as a normal turn so clients can rely on them.
            "latency_ms": 0.0,
        }

    resolved_lang = detect_language(msg, req.language_preference)

    sess["weaknesses"] = update_weaknesses_from_message(msg, sess["weaknesses"])
    sess["cognitive_state"] = update_cognitive_state_from_message(
        msg, sess["cognitive_state"]
    )

    # TODO: academic gating — port the is_academic_query logic from app.py
    # (ideally into an api utils module). It is not applied yet, so retrieval
    # currently runs for every message.
    rag_context_text, rag_used_chunks = retrieve_relevant_chunks(
        msg, sess["rag_chunks"] or []
    )

    # perf_counter is monotonic, so the measured latency is immune to
    # wall-clock adjustments (NTP, DST) that can skew time.time() deltas.
    start = time.perf_counter()
    answer, new_history = chat_with_clare(
        message=msg,
        history=sess["history"],
        model_name=sess["model_name"],
        language_preference=resolved_lang,
        learning_mode=req.learning_mode,
        doc_type=req.doc_type,
        course_outline=sess["course_outline"],
        weaknesses=sess["weaknesses"],
        cognitive_state=sess["cognitive_state"],
        rag_context=rag_context_text,
    )
    latency_ms = (time.perf_counter() - start) * 1000.0

    sess["history"] = new_history

    refs = [
        {"source_file": c.get("source_file"), "section": c.get("section")}
        for c in (rag_used_chunks or [])
    ]

    return {
        "reply": answer,
        "session_status_md": render_session_status(
            req.learning_mode, sess["weaknesses"], sess["cognitive_state"]
        ),
        "refs": refs,
        "latency_ms": latency_ms,
    }

@app.post("/api/upload")
async def upload(
    user_id: str = Form(...),
    doc_type: str = Form(...),
    file: UploadFile = File(...),
):
    """Ingest an uploaded course document into the caller's session.

    Steps:
    1. Write the file into a private temporary directory. The client-supplied
       name is reduced to its basename, so it cannot escape that directory, and
       uploads from different users cannot collide.
    2. For a "Syllabus" upload, replace the session's course outline.
    3. Chunk the file for retrieval and append the chunks to the session.

    The temporary copy is deleted once processing finishes.

    Returns ``ok``, the number of chunks added, and a Markdown status line.
    """
    sess = _get_session(user_id)
    content = await file.read()

    # Never trust the client filename as a path: keep only its last component.
    # "", "." and ".." would otherwise resolve to a directory.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    # Blocking helper: save the upload, update topics, and build its RAG chunks.
    def _ingest() -> list:
        tmp_dir = tempfile.mkdtemp(prefix="clare_upload_")
        try:
            tmp_path = os.path.join(tmp_dir, safe_name)
            with open(tmp_path, "wb") as f:
                f.write(content)

            # Update topics (syllabus only).
            if doc_type == "Syllabus":
                # syllabus_utils expects a Gradio-style file object exposing .name
                file_obj = types.SimpleNamespace(name=tmp_path)
                sess["course_outline"] = extract_course_topics_from_file(file_obj, doc_type)

            # Every document type contributes RAG chunks.
            return build_rag_chunks_from_file(tmp_path, doc_type) or []
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

    # File parsing/chunking is synchronous; run it in the threadpool so it does
    # not block the event loop for every other request.
    new_chunks = await run_in_threadpool(_ingest)
    sess["rag_chunks"] = (sess["rag_chunks"] or []) + new_chunks

    status_md = f"✅ Loaded Module 10 base reading + uploaded {doc_type} file."
    return {"ok": True, "added_chunks": len(new_chunks), "status_md": status_md}

class ExportReq(BaseModel):
    """Request body for POST /api/export."""

    user_id: str  # key into SESSIONS
    learning_mode: str  # passed through to export_conversation

@app.post("/api/export")
def api_export(req: ExportReq):
    """Export the caller's conversation via export_conversation.

    The export covers the session history, course outline, weaknesses and
    cognitive state, plus the requested learning mode. The result is returned
    under the "markdown" key.
    """
    sess = _get_session(req.user_id)
    history, outline = sess["history"], sess["course_outline"]
    weaknesses, cog_state = sess["weaknesses"], sess["cognitive_state"]
    exported = export_conversation(history, outline, req.learning_mode, weaknesses, cog_state)
    return {"markdown": exported}

class SummaryReq(BaseModel):
    """Request body for POST /api/summary."""

    user_id: str  # key into SESSIONS
    # NOTE(review): accepted but not read by api_summary.
    learning_mode: str
    # Passed to summarize_conversation as-is (not run through detect_language).
    language_preference: str = "Auto"

@app.post("/api/summary")
def api_summary(req: SummaryReq):
    """Summarize the caller's conversation with the session's model.

    The result is returned under the "markdown" key.

    NOTE(review): language_preference is forwarded unresolved (e.g. "Auto"),
    unlike /api/chat, which resolves it with detect_language first. Confirm
    that summarize_conversation handles "Auto" itself.
    """
    s = _get_session(req.user_id)
    summary_args = (
        s["history"],
        s["course_outline"],
        s["weaknesses"],
        s["cognitive_state"],
        s["model_name"],
        req.language_preference,
    )
    return {"markdown": summarize_conversation(*summary_args)}

@app.get("/api/memoryline")
def memoryline(user_id: str):
    """Return review-schedule data for the memory-line widget.

    v1 returns fixed placeholder values, which is fine because the front end
    only renders them. The call still ensures a session exists for user_id.
    """
    _get_session(user_id)
    payload = {"next_review_label": "T+7", "progress_pct": 0.4}
    return payload