# nursing-copilot-api / api_server.py
# MarcoLeung052 — "Update api_server.py" (commit cbc7067, verified)
# api_server.py
import hmac
from typing import List, Dict, Any

from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from backend.skill_loader import scan_all_skills
from backend.agent import set_trigger_map, run_agent
app = FastAPI(title="Nursing Copilot API")

# Origins allowed to call this API from a browser.
# A browser's Origin header is only scheme + host (+ port) and never carries a
# path, so an entry such as ".../NursingRecordCompletion_test" could never
# match; the bare GitHub Pages origin below already covers that site.
origins = [
    "https://marcoleung052.github.io",
    # NOTE(review): "*" makes the explicit entry above redundant. Together with
    # allow_credentials=True, Starlette mirrors the caller's Origin back, so any
    # website can make credentialed cross-origin requests to this API. Kept so
    # existing (e.g. local-dev) frontends keep working; replace with the real
    # list of frontend origins.
    "*",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -----------------------------
# Request / Response Models
# -----------------------------
class PredictionRequest(BaseModel):
    """Request body for POST /api/predict."""

    # Free-text input handed to the agent; the endpoint rejects blank values.
    prompt: str
    # Patient the prediction is for. Accepted as a JSON number or a numeric
    # string: the endpoint converts it with int() and answers 400 if that
    # fails. A bare `str` annotation would make Pydantic v2 reject JSON numbers
    # with a 422 before the endpoint ever runs.
    patient_id: int | str | None = None
    # Model name; predict_completion does not read this field.
    model: str | None = "gpt2-nursing"
class PredictionResponse(BaseModel):
    """Shape of a /api/predict reply: a list of completion objects.

    NOTE: not currently attached as ``response_model`` on the endpoint, so it
    documents the shape rather than enforcing it.
    """

    completions: list[dict[str, Any]]
# -----------------------------
# API Endpoint
# -----------------------------
@app.on_event("startup")
def startup_event():
    """Build the skill trigger map once, when the server starts.

    ``scan_all_skills()`` scans the skills (per the original note it also
    generates TRIGGER_MAP.json). Its result is then injected into the agent
    module through ``set_trigger_map()``.

    NOTE: ``on_event`` is deprecated in newer FastAPI releases in favour of
    lifespan handlers.
    """
    set_trigger_map(scan_all_skills())
from backend.ai_output import run_ai_output
# -----------------------------
# Run server
# -----------------------------
if __name__ == "__main__":
    # Local development entry point: `python api_server.py`.
    import uvicorn
    # The app is passed as the import string "api_server:app" (uvicorn needs an
    # import string when reload=True). uvicorn imports the `api_server` module
    # by that name, so the routes defined further down this file are still
    # registered even though this guard sits in the middle of the file.
    # NOTE(review): consider moving this guard to the end of the file.
    uvicorn.run("api_server:app", host="0.0.0.0", port=8000, reload=True)
# =================================================================
# 4. 資料庫設定(SQLite + SQLAlchemy)
# =================================================================
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
import sys
import os
# Make sibling modules such as models.py importable regardless of the launch
# directory. abspath() guards against a relative __file__, whose dirname()
# would be "" (i.e. "whatever the current directory happens to be").
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from models import Base, Patient, Nurse, Record

# The SQLite file is resolved against the process's current working directory
# (GET /download-db serves the same relative "nursing.db" path).
DATABASE_URL = "sqlite:///./nursing.db"
# check_same_thread=False: FastAPI runs sync endpoints in a thread pool, so a
# connection may be used by a thread other than the one that opened it.
engine = create_engine(
    DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create any tables that do not exist yet. Existing tables are left untouched,
# so model changes are not migrated.
Base.metadata.create_all(bind=engine)
def get_db():
    """FastAPI dependency that yields one SQLAlchemy session per request.

    Must be defined before the endpoints: they use ``Depends(get_db)`` as a
    default argument, and defaults are evaluated when each ``def`` runs.
    The session is closed once the request is done, whether or not the
    handler raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# =================================================================
# 5. Schemas
# =================================================================
from pydantic import BaseModel
# Login credentials: request body for POST /login (admin and nurses alike).
class LoginRequest(BaseModel):
    # The admin username, or a nurse's staff_id.
    username: str
    password: str
# Patient payload: request body for POST /patients and PUT /patients/{patient_id}.
# Fields are passed straight into Patient(**...) / setattr, so every name here
# must exist on the Patient model.
class PatientCreate(BaseModel):
    name: str
    mrn: str | None = None  # presumably the medical record number — confirm
    birth: str | None = None  # kept as a string; format is not validated here
    gender: str | None = None
    phone: str | None = None
    email: str | None = None
    emg_name: str | None = None  # presumably emergency-contact details — confirm
    emg_phone: str | None = None
    emg_relation: str | None = None
    room: str
    department: str
    doctor: str | None = None
    diagnosis: str
    risk: str  # allowed values are not validated here
    admit_date: str  # kept as a string; format is not validated here
# Nurse payload: request body for POST /nurses and PUT /nurses/{nurse_id}.
class NurseCreate(BaseModel):
    name: str
    staff_id: str  # also the login username checked by POST /login
    department: str | None = None
    position: str | None = None
    phone: str | None = None
    email: str | None = None
    # Optional; see the /nurses endpoints for how it is applied.
    password: str | None = None
# Nursing-record payload: request body for POST /records and PUT /records/{record_id}.
class RecordCreate(BaseModel):
    patient_id: int  # the patient this record belongs to (GET /records/{patient_id} filters on it)
    nurse_id: int  # presumably the authoring nurse's id — confirm
    content: str
    created_at: str  # client-supplied timestamp string; not validated here
# =================================================================
# 6. 病患 API
# =================================================================
@app.post("/api/predict")
def predict_completion(request: PredictionRequest, db: Session = Depends(get_db)):
    """Run the nursing agent on a prompt for one patient.

    Validates the request, then delegates to ``run_agent`` with the parsed
    integer patient id and this request's DB session.

    Returns:
        ``{"completions": [result]}`` where ``result`` is whatever
        ``run_agent`` returns.

    Raises:
        HTTPException(400): the prompt is empty/whitespace, ``patient_id`` is
            missing, or ``patient_id`` cannot be converted to an int.
    """
    # NOTE(review): this prints the raw prompt (likely patient data) to stdout;
    # consider proper logging with redaction.
    print("收到的 request:", request.dict())

    # `prompt` is a required str, so only the blank/whitespace case needs a check.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="prompt 不可為空")

    if request.patient_id is None:
        raise HTTPException(status_code=400, detail="需要 patient_id")

    try:
        patient_id = int(request.patient_id)
    except (TypeError, ValueError):
        # Only conversion failures map to 400. A bare `except:` would also
        # swallow unrelated errors (and KeyboardInterrupt/SystemExit).
        raise HTTPException(status_code=400, detail="patient_id 必須是整數") from None

    # Only the agent is called; it receives the DB session and patient id.
    result = run_agent(
        user_input=request.prompt,
        patient_id=patient_id,
        db=db,
    )
    return {"completions": [result]}
@app.post("/patients")
def create_patient(data: PatientCreate, db: Session = Depends(get_db)):
    """Insert a patient built from the request body and return the stored row."""
    fields = data.dict()
    new_patient = Patient(**fields)
    db.add(new_patient)
    db.commit()
    # Re-read the row so database-generated values (e.g. the primary key) are loaded.
    db.refresh(new_patient)
    return new_patient
@app.get("/patients")
def list_patients(db: Session = Depends(get_db)):
    """Return every patient row (no pagination or filtering)."""
    everyone = db.query(Patient).all()
    return everyone
@app.get("/patients/{patient_id}")
def get_patient(patient_id: int, db: Session = Depends(get_db)):
    """Return one patient by id; respond 404 when no such patient exists."""
    lookup = db.query(Patient).filter(Patient.id == patient_id)
    match = lookup.first()
    if not match:
        raise HTTPException(status_code=404, detail="找不到病患")
    return match
@app.put("/patients/{patient_id}")
def update_patient(patient_id: int, data: PatientCreate, db: Session = Depends(get_db)):
    """Overwrite every field of an existing patient (full PUT replacement).

    Optional fields left out of the body are written as None.
    Responds 404 when the patient does not exist.
    """
    target = db.query(Patient).filter(Patient.id == patient_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="找不到病患")

    updates = data.dict()
    for field_name in updates:
        setattr(target, field_name, updates[field_name])

    db.commit()
    db.refresh(target)
    return target
@app.delete("/patients/{patient_id}")
def delete_patient(patient_id: int, db: Session = Depends(get_db)):
    """Delete a patient by id; respond 404 when it does not exist."""
    doomed = db.query(Patient).filter(Patient.id == patient_id).first()
    if not doomed:
        raise HTTPException(status_code=404, detail="找不到病患")
    # NOTE(review): this endpoint does not touch the patient's nursing records.
    # Whether they cascade depends on the model definitions and on SQLite
    # foreign-key enforcement — confirm.
    db.delete(doomed)
    db.commit()
    return {"message": "病患已刪除"}
# =================================================================
# 7. 護士 API
# =================================================================
@app.post("/nurses")
def create_nurse(data: NurseCreate, db: Session = Depends(get_db)):
    """Create a nurse account and return the stored row.

    The initial password is the one supplied in the body. When it is omitted
    or empty, it defaults to the nurse's ``staff_id``, the same value that
    PUT /nurses/{nurse_id}/reset-password restores.

    NOTE(review): passwords are stored in plain text, and the returned ORM
    object presumably serializes the password column too. Hashing needs a
    coordinated change with POST /login.
    """
    nurse_data = data.dict()
    # Fall back to staff_id only when no password was provided.
    nurse_data["password"] = data.password or data.staff_id
    nurse = Nurse(**nurse_data)
    db.add(nurse)
    db.commit()
    db.refresh(nurse)
    return nurse
@app.delete("/nurses/{nurse_id}")
def delete_nurse(nurse_id: int, db: Session = Depends(get_db)):
    """Delete a nurse by id; respond 404 when no such nurse exists."""
    found = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="找不到護理師")
    db.delete(found)
    db.commit()
    return {"message": "護理師已刪除"}
@app.get("/nurses")
def list_nurses(db: Session = Depends(get_db)):
    """Return all nurse rows.

    NOTE(review): ORM objects are returned as-is, so the response presumably
    includes each nurse's plain-text password. Consider a response_model that
    omits it.
    """
    all_nurses = db.query(Nurse).all()
    return all_nurses
@app.get("/nurses/{nurse_id}")
def get_nurse(nurse_id: int, db: Session = Depends(get_db)):
    """Fetch one nurse by id; respond 404 when it does not exist."""
    lookup = db.query(Nurse).filter(Nurse.id == nurse_id)
    nurse_row = lookup.first()
    if not nurse_row:
        raise HTTPException(status_code=404, detail="找不到護理師")
    return nurse_row
# DELETE /nurses/{nurse_id} is served by the `delete_nurse` handler defined
# earlier in this file. A verbatim second registration of the same route stood
# here. It was unreachable (the first matching route wins), rebound the
# module-level name, and produced a duplicate OpenAPI operation id, so it was
# removed.
@app.put("/nurses/{nurse_id}")
def update_nurse(nurse_id: int, data: NurseCreate, db: Session = Depends(get_db)):
    """Overwrite a nurse's profile fields (full PUT replacement).

    Optional profile fields left out of the body are written as None. The
    exception is ``password``: when it is None or empty, the stored password
    is kept. Otherwise a profile edit would store None (a password that can
    never match at login) or "" (an empty-password login). Use
    PUT /nurses/{nurse_id}/reset-password to restore the default.

    Raises:
        HTTPException(404): the nurse does not exist.
    """
    nurse = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not nurse:
        raise HTTPException(status_code=404, detail="找不到護理師")
    for key, value in data.dict().items():
        if key == "password" and not value:
            continue  # no new password supplied: keep the current one
        setattr(nurse, key, value)
    db.commit()
    db.refresh(nurse)
    return nurse
@app.put("/nurses/{nurse_id}/reset-password")
def reset_password(nurse_id: int, db: Session = Depends(get_db)):
    """Reset a nurse's password to the default, i.e. their own ``staff_id``."""
    account = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not account:
        raise HTTPException(status_code=404, detail="找不到護理師")
    account.password = account.staff_id
    db.commit()
    return {"message": "密碼已重設"}
# =================================================================
# 8. 護理紀錄 API
# =================================================================
@app.post("/records")
def create_record(data: RecordCreate, db: Session = Depends(get_db)):
    """Store a new nursing record and return it as re-read from the database."""
    payload = data.dict()
    entry = Record(**payload)
    db.add(entry)
    db.commit()
    db.refresh(entry)
    return entry
@app.get("/records/{patient_id}")
def list_records(patient_id: int, db: Session = Depends(get_db)):
    """Return all nursing records for one patient ([] if there are none).

    Note: the path segment here is a *patient* id, whereas
    PUT/DELETE /records/{record_id} take a *record* id.
    """
    by_patient = db.query(Record).filter(Record.patient_id == patient_id)
    return by_patient.all()
@app.get("/records/detail/{record_id}")
def get_record_detail(record_id: int, db: Session = Depends(get_db)):
    """Return a single nursing record by its id; respond 404 if absent."""
    hit = db.query(Record).filter(Record.id == record_id).first()
    if not hit:
        raise HTTPException(status_code=404, detail="找不到紀錄")
    return hit
@app.put("/records/{record_id}")
def update_record(record_id: int, data: RecordCreate, db: Session = Depends(get_db)):
    """Replace every field of an existing nursing record; respond 404 if absent."""
    existing = db.query(Record).filter(Record.id == record_id).first()
    if not existing:
        raise HTTPException(status_code=404, detail="找不到紀錄")

    changes = data.dict()
    for attr in changes:
        setattr(existing, attr, changes[attr])

    db.commit()
    db.refresh(existing)
    return existing
@app.delete("/records/{record_id}")
def delete_record(record_id: int, db: Session = Depends(get_db)):
    """Remove a nursing record by id; respond 404 if absent."""
    lookup = db.query(Record).filter(Record.id == record_id)
    stale = lookup.first()
    if not stale:
        raise HTTPException(status_code=404, detail="找不到紀錄")
    db.delete(stale)
    db.commit()
    return {"message": "紀錄已刪除"}
def _secret_matches(supplied: str, expected: object) -> bool:
    """Return True when ``supplied`` equals the stored secret, in constant time.

    A non-str stored value (e.g. a NULL password) never matches, just as a
    plain ``!=`` against a str never would. Both sides are UTF-8 encoded
    because ``hmac.compare_digest`` rejects non-ASCII ``str`` arguments.
    """
    if not isinstance(expected, str):
        return False
    return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))


@app.post("/login")
def login(data: LoginRequest, db: Session = Depends(get_db)):
    """Authenticate the administrator or a nurse.

    The admin account is checked first. Its credentials come from the
    ``ADMIN_USERNAME`` / ``ADMIN_PASSWORD`` environment variables and fall
    back to ``admin`` / ``1234`` when those are unset. Any other username is
    looked up as a nurse ``staff_id``.

    Returns:
        ``{"role", "name", "id"}``; the admin always has id 0.

    Raises:
        HTTPException(401): unknown username or wrong password.

    NOTE(review): nurse passwords are stored in plain text, and the two
    distinct 401 messages let callers probe which staff ids exist.
    """
    # Admin login
    admin_username = os.environ.get("ADMIN_USERNAME", "admin")
    admin_password = os.environ.get("ADMIN_PASSWORD", "1234")
    if data.username == admin_username and _secret_matches(data.password, admin_password):
        return {"role": "admin", "name": "Admin", "id": 0}

    # Nurse login
    nurse = db.query(Nurse).filter(Nurse.staff_id == data.username).first()
    if not nurse:
        raise HTTPException(status_code=401, detail="帳號不存在")
    if not _secret_matches(data.password, nurse.password):
        raise HTTPException(status_code=401, detail="密碼錯誤")
    return {
        "role": "nurse",
        "name": nurse.name,
        "id": nurse.id,
    }
@app.get("/current-user")
def current_user(token: str | None = None, db: Session = Depends(get_db)):
    """Resolve the caller from the ``token`` query parameter.

    The literal token ``"admin"`` maps to the admin account. Any other token
    is parsed as a nurse id.

    NOTE(review): the "token" is just "admin" or a numeric nurse id, so anyone
    can impersonate any user. Replace it with real, unguessable session tokens.

    Raises:
        HTTPException(401): token missing, not an integer, or no such nurse.
    """
    if token == "admin":
        return {"role": "admin", "name": "Admin", "id": 0}
    if token is None:
        raise HTTPException(status_code=401, detail="未登入")
    try:
        nurse_id = int(token)
    except ValueError:
        # Treat a non-numeric token as an unknown user (401) rather than
        # letting int() surface as a 500.
        raise HTTPException(status_code=401, detail="登入者不存在") from None
    nurse = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not nurse:
        raise HTTPException(status_code=401, detail="登入者不存在")
    return {
        "role": "nurse",
        "name": nurse.name,
        "id": nurse.id,
    }
from fastapi.responses import FileResponse
@app.get("/download-db")
def download_db():
    """Send the SQLite database file ("nursing.db", relative to the working
    directory) to the caller as a file download.

    NOTE(review): this endpoint has no authentication and exposes every
    patient, every nurse (including passwords) and every nursing record.
    Restrict or remove it before any real deployment.
    """
    db_file = "nursing.db"
    return FileResponse(db_file, filename=db_file)