MarcoLeung052 committed on
Commit
1d85dd7
·
verified ·
1 Parent(s): 657bed4

Upload 2 files

Browse files
Files changed (2) hide show
  1. api_server (1).py +124 -0
  2. requirements (1).txt +6 -0
api_server (1).py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # api_server.py
2
+
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
7
+ import torch
8
+
9
# =================================================================
# 1. Application initialization and model loading
# =================================================================
app = FastAPI(title="GPT-2 Nursing Completion API")

# CORS setup: allow the frontend page (localhost or your server IP) to call this API.
# ⚠️ NOTE: in production, replace the wildcard below with your frontend's real origin!
origins = [
    #"http://localhost:5500",  # e.g. VS Code Live Server or similar tooling
    #"http://127.0.0.1:5500",
    "https://marcoleung052.github.io/NursingRecordCompletion_train//step7/%E8%AD%B7%E7%90%86%E7%B4%80%E9%8C%84%E7%B3%BB%E7%B5%B1demo.html",
    "*" # temporarily allow every origin for easier testing
]

# NOTE(review): "*" combined with allow_credentials=True conflicts with the CORS
# spec (browsers reject wildcard origins on credentialed requests) — confirm the
# frontend sends no credentials, or list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level holders for the tokenizer and model; populated by the
# startup hook and read by the /api/predict endpoint.
tokenizer = None
model = None
MODEL_PATH = "gpt2" # replace with the directory of your fine-tuned model
35
+
36
@app.on_event("startup")
async def load_model():
    """Load the GPT-2 tokenizer and weights when the application starts.

    On success the module-level ``tokenizer``/``model`` globals are populated;
    on failure they stay ``None`` so the endpoint can answer 503 instead of
    crashing the server.
    """
    global tokenizer, model
    try:
        # Pull both halves of the pipeline from the same checkpoint path.
        loaded_tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
        loaded_model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
        # Evaluation mode: disables dropout and other train-time behavior.
        # (If memory allows, a GPU could be used here via model.to(device).)
        loaded_model.eval()

        # Publish via the globals only once loading is complete.
        tokenizer = loaded_tokenizer
        model = loaded_model
        print(f"✅ GPT-2 模型 {MODEL_PATH} 載入成功!")
    except Exception as e:
        # Best-effort startup: report and keep serving so the failure is visible
        # through the endpoint's 503 rather than a dead process.
        print(f"❌ 模型載入失敗,請檢查 MODEL_PATH 或依賴庫是否安裝:{e}")
54
+
55
+
56
# =================================================================
# 2. API request and response schemas
# =================================================================
class PredictionRequest(BaseModel):
    """Request body sent by the frontend."""
    # Free-text prompt to complete (required).
    prompt: str
    # Optional patient identifier; not read by the prediction logic in this file.
    patient_id: str | None = None
    # Model tag sent by the client; informational only here.
    model: str | None = "gpt2-nursing"
64
+
65
class PredictionResponse(BaseModel):
    """Response body returned by the backend."""
    # Generated completion candidates; each string includes the original
    # prompt as a prefix (or is the prompt itself in the fallback case).
    completions: list[str]
68
+
69
# =================================================================
# 3. Core API endpoint (generates 3 candidate sequences)
# =================================================================
@app.post("/api/predict", response_model=PredictionResponse)
def predict_completion(request: PredictionRequest):
    """Generate DART nursing-record completions for the given prompt.

    Returns up to 3 unique completions, longest first; falls back to echoing
    the prompt itself when no generated sequence starts with the prompt.

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for over-long
            input, 500 on any inference failure.
    """
    if model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="AI 模型服務尚未準備就緒,請檢查後端日誌。")

    input_text = request.prompt
    if len(input_text) > 512:
        raise HTTPException(status_code=400, detail="輸入過長,請限制在 512 個字元內。")

    try:
        input_ids = tokenizer.encode(input_text, return_tensors='pt', truncation=True)

        # Inference only: no_grad skips building the autograd graph,
        # saving memory and time.
        with torch.no_grad():
            # num_return_sequences=3 yields three different candidates.
            # FIX: the original passed max_length=len(input_text) + 150, which
            # feeds a *character* count into a *token* budget; max_new_tokens
            # expresses the actual intent (150 freshly generated tokens).
            output = model.generate(
                input_ids,
                max_new_tokens=150,
                num_return_sequences=3,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                pad_token_id=tokenizer.eos_token_id
            )

        # Keep only sequences that actually continue the user's input.
        all_completions = []
        for sequence in output:
            generated_text = tokenizer.decode(sequence, skip_special_tokens=True)
            if generated_text.startswith(input_text):
                all_completions.append(generated_text)

        # De-duplicate and prefer the longest completions first.
        unique_completions = sorted(set(all_completions), key=len, reverse=True)

        if not unique_completions:
            # Model produced nothing usable: return the user's input itself.
            return {"completions": [input_text]}

        # All unique completions (at most 3).
        return {"completions": unique_completions}

    except Exception as e:
        print(f"推論過程發生錯誤: {e}")
        raise HTTPException(status_code=500, detail=f"模型推論失敗:{str(e)[:50]}...")
119
+
120
# Run the development server.
if __name__ == "__main__":
    import uvicorn
    # host 0.0.0.0 allows external access; port 8000 matches the frontend settings.
    # NOTE(review): reload=True is a development convenience — disable in production.
    uvicorn.run("api_server:app", host="0.0.0.0", port=8000, reload=True)
requirements (1).txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # requirements.txt
2
+ fastapi
3
+ uvicorn
4
+ torch
5
+ transformers
6
+ pydantic