File size: 8,772 Bytes
594ed40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import asyncio
import json
import logging
import uuid
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query, WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from backend.app.services.inference_service import InferenceService, get_inference_service
from backend.app.services.cache_service import get_cache_service
from backend.app.middleware.auth import get_current_user, get_current_user_optional, User

# Router holding this module's question endpoints: POST "/" (submit_question)
# and the WebSocket route "/ws/{session_id}" (websocket_endpoint).
router = APIRouter()


# --- リクエスト/レスポンススキーマの定義 ---
from typing import Optional, Literal


class QuestionRequest(BaseModel):
    """Request body for ``POST /`` (handled by ``submit_question``)."""

    # Text of the question; forwarded to the inference service unchanged.
    question: str
    # Session to continue. When missing or empty, submit_question generates
    # a new session id.
    session_id: Optional[str] = None
    # Domain identifier forwarded to the inference service.
    domain_id: str = "medical"
    # Optional model override; None is passed through to the service as-is.
    model_id: Optional[str] = None
    # NOTE(review): not read by submit_question in this module — confirm
    # whether HTTP streaming is handled elsewhere or this field is vestigial.
    stream: bool = False
    # Answering mode forwarded to the service; presumably "rag" enables
    # retrieval-augmented answering and "direct" skips it — confirm in
    # InferenceService.
    rag_mode: Literal["direct", "rag"] = "rag"


class QuestionResponse(BaseModel):
    """Response body returned by ``POST /`` (``submit_question``)."""

    # Session the question was processed under (client-supplied or generated).
    session_id: str
    # Echo of the submitted question text.
    question: str
    # Answer text: the service result's "answer", else its "response", else a
    # fixed fallback message.
    response: str
    # Status reported by the service; submit_question uses "error" when the
    # service result has no "status" key.
    status: str
    # The remaining fields are copied from the service result when present
    # and are None otherwise.
    confidence: Optional[float] = None
    memory_augmented: Optional[bool] = None
    thinking: Optional[str] = None
    model_used: Optional[str] = None


# --- APIエンドポイント ---
@router.post("/", response_model=QuestionResponse)
async def submit_question(
    request: QuestionRequest,
    current_user: Optional[User] = Depends(get_current_user_optional),
    service: InferenceService = Depends(get_inference_service)
):
    """Submit a question and run it through the inference service.

    Unauthenticated (guest) callers are allowed; they are processed under
    the user id ``"guest"``.

    Args:
        request: The question payload.
        current_user: The authenticated user, or ``None`` for guests.
        service: Injected inference service.

    Returns:
        A ``QuestionResponse`` built from the service result. ``status``
        falls back to ``"error"`` when the result does not report one.

    Raises:
        HTTPException: Re-raised unchanged when the service raises one; any
            other exception is logged and converted into a 500 response.
    """
    # Guests have no account, so they share a fixed user id.
    user_id = current_user.id if current_user else "guest"

    # Generate a fresh session id when the client did not supply one.
    # The previous ``hash(request.question)`` scheme was randomized per
    # process (PYTHONHASHSEED), could be negative, and made every guest who
    # asked the same question share one session. A random UUID avoids all three.
    session_id = request.session_id or f"sess_{user_id}_{uuid.uuid4().hex}"

    try:
        result = await service.process_question(
            question=request.question,
            user_id=user_id,
            session_id=session_id,
            domain_id=request.domain_id,
            model_id=request.model_id,
            rag_mode=request.rag_mode
        )

        return QuestionResponse(
            session_id=session_id,
            question=request.question,
            response=result.get("answer", result.get("response", "回答が得られませんでした。")),
            status=result.get("status", "error"),
            confidence=result.get("confidence"),
            memory_augmented=result.get("memory_augmented"),
            thinking=result.get("thinking"),
            model_used=result.get("model_used")
        )
    except HTTPException:
        # A deliberate HTTP error (status code + detail) from the service:
        # pass it through instead of masking it as a generic 500.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Question processing failed (session_id=%s)", session_id
        )
        raise HTTPException(status_code=500, detail=str(e)) from e

@router.websocket("/ws/{session_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    session_id: str
):
    """Serve questions over a WebSocket, streaming answers token by token.

    Client -> server messages are JSON objects with a ``type`` field:

    * ``"question"``: ask a question. Optional keys: ``question`` (default
      ``""``), ``domain_id`` (default ``"medical"``), ``model_id``,
      ``stream`` (default ``True``) and ``rag_mode`` (default ``"rag"``).
    * ``"ping"``: answered with ``{"type": "pong"}``.
    * ``"close"``: ends the session; the server then closes the socket.

    Messages with any other ``type`` are ignored. Invalid JSON or a
    non-object payload is answered with an ``error`` message, and the
    connection stays open.

    Server -> client messages use the types ``connected``, ``processing``,
    ``thinking``, ``token``, ``response``, ``error`` and ``pong``.

    Args:
        websocket: The client connection.
        session_id: Session identifier from the URL path. It is passed to
            the inference service for every question asked on this socket.
    """
    log = logging.getLogger(__name__)
    await websocket.accept()

    service = get_inference_service()
    # Set once the peer has disconnected, so the final close() is skipped
    # (closing an already-closed socket only raises).
    client_disconnected = False

    try:
        # Connection acknowledgement.
        await websocket.send_json({
            "type": "connected",
            "session_id": session_id,
            "message": "WebSocket connected"
        })

        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                # A malformed frame is a client bug, not a reason to drop
                # the whole session.
                await websocket.send_json({
                    "type": "error",
                    "error": "Invalid JSON message"
                })
                continue

            if not isinstance(data, dict):
                await websocket.send_json({
                    "type": "error",
                    "error": "Message must be a JSON object"
                })
                continue

            message_type = data.get("type")
            if message_type == "question":
                await _ws_handle_question(websocket, service, session_id, data)
            elif message_type == "ping":
                await websocket.send_json({"type": "pong"})
            elif message_type == "close":
                break

    except WebSocketDisconnect:
        # The peer is gone: nobody is left to receive an error message.
        client_disconnected = True
    except Exception as e:
        log.exception("WebSocket session %s failed", session_id)
        try:
            await websocket.send_json({
                "type": "error",
                "error": str(e)
            })
        except Exception:
            # Best effort: the socket may already be unusable.
            log.debug("Could not report error to WebSocket client", exc_info=True)
    finally:
        if not client_disconnected:
            try:
                await websocket.close()
            except Exception:
                # Best effort: the connection may already be closed.
                log.debug("Error while closing WebSocket", exc_info=True)


async def _ws_handle_question(
    websocket: WebSocket,
    service: InferenceService,
    session_id: str,
    data: dict,
) -> None:
    """Answer one ``question`` message and report failures to the client.

    Generation errors are sent back as an ``error`` message so the session
    survives. A client disconnect is re-raised for the connection loop.
    """
    question = data.get("question", "")
    domain_id = data.get("domain_id", "medical")
    model_id = data.get("model_id")
    use_streaming = data.get("stream", True)
    rag_mode = data.get("rag_mode", "rag")

    # Tell the client that work has started.
    await websocket.send_json({
        "type": "processing",
        "message": "Processing your question..."
    })

    try:
        if use_streaming:
            await _ws_stream_answer(
                websocket, service, session_id, question, domain_id, model_id, rag_mode
            )
        else:
            await _ws_answer_once(
                websocket, service, session_id, question, domain_id, model_id, rag_mode
            )
    except WebSocketDisconnect:
        # Sending an error to a vanished client would fail again; let the
        # connection-level handler finish the session.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Failed to answer question on WebSocket session %s", session_id
        )
        await websocket.send_json({
            "type": "error",
            "error": str(e)
        })


async def _ws_stream_answer(
    websocket: WebSocket,
    service: InferenceService,
    session_id: str,
    question: str,
    domain_id: str,
    model_id,
    rag_mode: str,
) -> None:
    """Relay chunks from ``service.stream_tokens`` until a complete or error chunk."""
    await websocket.send_json({
        "type": "thinking",
        "step": "Initializing model..."
    })

    stream = service.stream_tokens(
        session_id=session_id,
        question=question,
        domain_id=domain_id,
        model_id=model_id,
        rag_mode=rag_mode
    )
    try:
        async for chunk in stream:
            chunk_type = chunk.get("type", "")

            if chunk_type == "token":
                # Forward each token as soon as it arrives.
                await websocket.send_json({
                    "type": "token",
                    "content": chunk.get("content", "")
                })
            elif chunk_type == "thinking":
                await websocket.send_json({
                    "type": "thinking",
                    "step": chunk.get("content", "")
                })
            elif chunk_type == "complete":
                await websocket.send_json({
                    "type": "response",
                    "session_id": session_id,
                    "question": question,
                    "response": chunk.get("content", ""),
                    "status": "success"
                })
                return
            elif chunk_type == "error":
                await websocket.send_json({
                    "type": "error",
                    "error": chunk.get("content", chunk.get("message", "Unknown error"))
                })
                return
            elif chunk_type == "start":
                await websocket.send_json({
                    "type": "thinking",
                    "step": "Starting generation..."
                })
            # "heartbeat" chunks (and unknown types) are not forwarded.
    finally:
        # Leaving an ``async for`` early does not finalize an async
        # generator. Close it now so the producer stops generating instead
        # of lingering until garbage collection.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()


async def _ws_answer_once(
    websocket: WebSocket,
    service: InferenceService,
    session_id: str,
    question: str,
    domain_id: str,
    model_id,
    rag_mode: str,
) -> None:
    """Run a single non-streaming inference and send the full result."""
    await websocket.send_json({
        "type": "thinking",
        "step": "Processing..."
    })

    # NOTE(review): WebSocket callers are not authenticated here and are all
    # processed as "ws_user" — confirm this is intended.
    result = await service.process_question(
        question=question,
        user_id="ws_user",
        session_id=session_id,
        domain_id=domain_id,
        model_id=model_id,
        rag_mode=rag_mode
    )

    # Final answer.
    await websocket.send_json({
        "type": "response",
        "session_id": session_id,
        "question": question,
        "response": result.get("answer", result.get("response", "")),
        "status": result.get("status", "error"),
        "confidence": result.get("confidence"),
        "thinking": result.get("thinking"),
        "model_used": result.get("model_used")
    })