# kofdai's picture
# Upload folder using huggingface_hub
# 594ed40 verified
import asyncio
import json
import uuid
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, WebSocket, Query
from pydantic import BaseModel

from backend.app.middleware.auth import get_current_user, get_current_user_optional, User
from backend.app.services.cache_service import get_cache_service
from backend.app.services.inference_service import InferenceService, get_inference_service
router = APIRouter()
# --- Request/response schema definitions ---
from typing import Optional, Literal
class QuestionRequest(BaseModel):
question: str
session_id: Optional[str] = None
domain_id: str = "medical"
model_id: Optional[str] = None
stream: bool = False
rag_mode: Literal["direct", "rag"] = "rag"
class QuestionResponse(BaseModel):
session_id: str
question: str
response: str
status: str
confidence: Optional[float] = None
memory_augmented: Optional[bool] = None
thinking: Optional[str] = None
model_used: Optional[str] = None
# --- API endpoints ---
@router.post("/", response_model=QuestionResponse)
async def submit_question(
request: QuestionRequest,
current_user: Optional[User] = Depends(get_current_user_optional),
service: InferenceService = Depends(get_inference_service)
):
"""
質問を提出し、推論エンジンで処理。
ゲストユーザーでもアクセス可能。
"""
# ゲストユーザーの場合は "guest" として扱う
user_id = current_user.id if current_user else "guest"
# Session IDがなければ新規生成
session_id = request.session_id if request.session_id else f"sess_{user_id}_{hash(request.question)}"
try:
# 依存性注入されたInferenceServiceを呼び出す
result = await service.process_question(
question=request.question,
user_id=user_id,
session_id=session_id,
domain_id=request.domain_id,
model_id=request.model_id,
rag_mode=request.rag_mode
)
return QuestionResponse(
session_id=session_id,
question=request.question,
response=result.get("answer", result.get("response", "回答が得られませんでした。")),
status=result.get("status", "error"),
confidence=result.get("confidence"),
memory_augmented=result.get("memory_augmented"),
thinking=result.get("thinking"),
model_used=result.get("model_used")
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.websocket("/ws/{session_id}")
async def websocket_endpoint(
websocket: WebSocket,
session_id: str
):
"""
WebSocketで回答のストリーミング配信。
トークン単位でリアルタイムに配信。
"""
await websocket.accept()
# InferenceServiceのインスタンスを作成
service = get_inference_service()
try:
# 接続確認メッセージ
await websocket.send_json({
"type": "connected",
"session_id": session_id,
"message": "WebSocket connected"
})
while True:
# クライアントからのメッセージを待機
data = await websocket.receive_json()
if data.get("type") == "question":
question = data.get("question", "")
domain_id = data.get("domain_id", "medical")
model_id = data.get("model_id")
use_streaming = data.get("stream", True)
rag_mode = data.get("rag_mode", "rag")
# 処理開始通知
await websocket.send_json({
"type": "processing",
"message": "Processing your question..."
})
try:
if use_streaming:
# ストリーミングモードで生成
await websocket.send_json({
"type": "thinking",
"step": "Initializing model..."
})
generated_tokens = []
async for chunk in service.stream_tokens(
session_id=session_id,
question=question,
domain_id=domain_id,
model_id=model_id,
rag_mode=rag_mode
):
chunk_type = chunk.get("type", "")
if chunk_type == "token":
# トークンをリアルタイムで送信
token = chunk.get("content", "")
generated_tokens.append(token)
await websocket.send_json({
"type": "token",
"content": token
})
elif chunk_type == "thinking":
await websocket.send_json({
"type": "thinking",
"step": chunk.get("content", "")
})
elif chunk_type == "complete":
# 完了メッセージ
await websocket.send_json({
"type": "response",
"session_id": session_id,
"question": question,
"response": chunk.get("content", ""),
"status": "success"
})
break
elif chunk_type == "error":
await websocket.send_json({
"type": "error",
"error": chunk.get("content", chunk.get("message", "Unknown error"))
})
break
elif chunk_type == "heartbeat":
# ハートビートは無視(接続維持のため)
continue
elif chunk_type == "start":
await websocket.send_json({
"type": "thinking",
"step": "Starting generation..."
})
else:
# 非ストリーミングモード
await websocket.send_json({
"type": "thinking",
"step": "Processing..."
})
result = await service.process_question(
question=question,
user_id="ws_user",
session_id=session_id,
domain_id=domain_id,
model_id=model_id,
rag_mode=data.get("rag_mode", "rag")
)
# 最終回答を送信
await websocket.send_json({
"type": "response",
"session_id": session_id,
"question": question,
"response": result.get("answer", result.get("response", "")),
"status": result.get("status", "error"),
"confidence": result.get("confidence"),
"thinking": result.get("thinking"),
"model_used": result.get("model_used")
})
except Exception as e:
import traceback
traceback.print_exc()
await websocket.send_json({
"type": "error",
"error": str(e)
})
elif data.get("type") == "ping":
await websocket.send_json({"type": "pong"})
elif data.get("type") == "close":
break
except Exception as e:
try:
await websocket.send_json({
"type": "error",
"error": str(e)
})
except:
pass
finally:
try:
await websocket.close()
except:
pass