MarcoLeung052 committed on
Commit
e4d2a60
·
verified ·
1 Parent(s): c770c0c

Update api_server.py

Browse files
Files changed (1) hide show
  1. api_server.py +17 -51
api_server.py CHANGED
@@ -3,25 +3,14 @@
3
  from fastapi import FastAPI, HTTPException, Depends
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
6
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
7
- import torch
8
  from backend.agent import run_agent
9
 
10
- # =================================================================
11
- # 1. 應用程式初始化與模型載入
12
- # =================================================================
13
- app = FastAPI(title="GPT-2 Nursing Completion API")
14
 
15
- # 設置 CORS:允許前端頁面 (localhost 或您的服務器 IP) 訪問
16
- # ⚠️ 注意:在生產環境中,請將 "http://localhost:5500" 替換為您的前端域名!
17
  origins = [
18
- #"http://localhost:5500", # 假設您使用 VS Code Live Server 或類似工具
19
- #"http://127.0.0.1:5500",
20
- # 這是您的 GitHub Pages 域名(標準格式)
21
- "https://marcoleung052.github.io",
22
- # 這是您的 GitHub Pages 子專案路徑 (如果使用子路徑)
23
  "https://marcoleung052.github.io/NursingRecordCompletion_test",
24
- "*" # 為了測試方便,暫時允許所有來源
25
  ]
26
 
27
  app.add_middleware(
@@ -32,64 +21,41 @@ app.add_middleware(
32
  allow_headers=["*"],
33
  )
34
 
35
- # 全局變數用於存儲模型和分詞器
36
- tokenizer = None
37
- model = None
38
- MODEL_PATH = "gpt2" # 這裡可以替換為您微調後的模型資料夾路徑
39
-
40
- @app.on_event("startup")
41
- async def load_model():
42
- """在應用啟動時載入 GPT-2 模型"""
43
- global tokenizer, model
44
- try:
45
- # 載入分詞器
46
- tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
47
-
48
- # 載入預訓練模型或您微調的模型權重
49
- # 如果您的記憶體允許,可以考慮使用 GPU
50
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
- model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
52
- # model.to(device)
53
- model.eval() # 設定為評估模式
54
-
55
- print(f"✅ GPT-2 模型 {MODEL_PATH} 載入成功!")
56
- except Exception as e:
57
- print(f"❌ 模型載入失敗,請檢查 MODEL_PATH 或依賴庫是否安裝:{e}")
58
-
59
-
60
- # =================================================================
61
- # 2. API 請求與響應格式
62
- # =================================================================
63
  class PredictionRequest(BaseModel):
64
- """前端發送的請求體格式"""
65
  prompt: str
66
  patient_id: str | None = None
67
  model: str | None = "gpt2-nursing"
68
 
69
  class PredictionResponse(BaseModel):
70
- """後端回傳的響應體格式"""
71
  completions: list[str]
72
 
73
- # =================================================================
74
- # 3. 核心 API 端點 (已修改為生成 3 個序列)
75
- # =================================================================
 
76
  @app.post("/api/predict", response_model=PredictionResponse)
77
  def predict_completion(request: PredictionRequest):
78
 
79
  input_text = request.prompt
80
 
81
- # 交給 agent 處理(固定 or AI)
82
  result = run_agent(input_text)
83
 
84
- # agent 回傳的結果格式統一 list
85
  return {"completions": result}
86
 
87
- # 運行伺服器
 
 
 
88
  if __name__ == "__main__":
89
  import uvicorn
90
- # host 0.0.0.0 允許外部訪問,port 8000 與前端設定一致
91
  uvicorn.run("api_server:app", host="0.0.0.0", port=8000, reload=True)
92
 
 
93
  # =================================================================
94
  # 4. 資料庫設定(SQLite + SQLAlchemy)
95
  # =================================================================
 
3
  from fastapi import FastAPI, HTTPException, Depends
4
  from fastapi.middleware.cors import CORSMiddleware
5
  from pydantic import BaseModel
 
 
6
  from backend.agent import run_agent
7
 
8
+ app = FastAPI(title="Nursing Copilot API")
 
 
 
9
 
 
 
10
  origins = [
11
+ "https://marcoleung052.github.io",
 
 
 
 
12
  "https://marcoleung052.github.io/NursingRecordCompletion_test",
13
+ "*"
14
  ]
15
 
16
  app.add_middleware(
 
21
  allow_headers=["*"],
22
  )
23
 
24
+ # -----------------------------
25
+ # Request / Response Models
26
+ # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class PredictionRequest(BaseModel):
 
28
  prompt: str
29
  patient_id: str | None = None
30
  model: str | None = "gpt2-nursing"
31
 
32
  class PredictionResponse(BaseModel):
 
33
  completions: list[str]
34
 
35
+
36
+ # -----------------------------
37
+ # API Endpoint
38
+ # -----------------------------
39
  @app.post("/api/predict", response_model=PredictionResponse)
40
  def predict_completion(request: PredictionRequest):
41
 
42
  input_text = request.prompt
43
 
44
+ # 交給 agent(固定 or AI)
45
  result = run_agent(input_text)
46
 
47
+ # agent 統一回傳 list
48
  return {"completions": result}
49
 
50
+
51
+ # -----------------------------
52
+ # Run server
53
+ # -----------------------------
54
  if __name__ == "__main__":
55
  import uvicorn
 
56
  uvicorn.run("api_server:app", host="0.0.0.0", port=8000, reload=True)
57
 
58
+
59
  # =================================================================
60
  # 4. 資料庫設定(SQLite + SQLAlchemy)
61
  # =================================================================