|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from backend.skill_loader import scan_all_skills |
|
|
from backend.agent import set_trigger_map, run_agent |
|
|
from typing import List, Dict, Any |
|
|
|
|
|
|
|
|
# FastAPI application instance; every route in this module is registered on it.
app = FastAPI(title="Nursing Copilot API")

# Origins accepted by the CORS middleware below.
#
# Browsers send the ``Origin`` header as scheme://host[:port] only -- never with a
# path -- so an entry such as ".../NursingRecordCompletion_test" can never match.
# That entry has been dropped; the GitHub Pages origin listed here covers that site.
#
# NOTE(review): "*" makes the explicit entry redundant. Combined with
# ``allow_credentials=True``, Starlette's CORSMiddleware echoes the requesting origin
# back instead of a literal "*", so any website can make credentialed cross-origin
# calls to this API. Narrow this list to the real front-end origins once known.
origins = [
    "https://marcoleung052.github.io",
    "*",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PredictionRequest(BaseModel):
    """Request body for ``POST /api/predict``.

    NOTE(review): ``patient_id`` is declared as a string, but ``predict_completion``
    requires it to parse as an integer. Depending on the pydantic version, a client
    that sends a JSON number may be rejected. Confirm what the front end sends.
    """

    # Free-text input handed to the agent as ``user_input``; must not be blank.
    prompt: str

    # Patient identifier as text; the endpoint rejects None and converts it with int().
    patient_id: str | None = None

    # Model name, defaulting to "gpt2-nursing". Not read by predict_completion in this file.
    model: str | None = "gpt2-nursing"
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Response shape for predictions: a list of completion dicts.

    NOTE(review): ``predict_completion`` in this file does not declare this as its
    ``response_model``; it returns a plain ``{"completions": [...]}`` dict instead.
    """

    # Completion entries; predict_completion wraps run_agent's result in a one-item list.
    completions: List[Dict[str, Any]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
def startup_event():
    """Scan all skills once at startup and hand the resulting trigger map to the agent.

    NOTE(review): ``on_event`` is deprecated in newer FastAPI versions in favour of
    lifespan handlers; consider migrating.
    """
    set_trigger_map(scan_all_skills())
|
|
|
|
|
|
|
|
from backend.ai_output import run_ai_output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app with auto-reload enabled.
    # NOTE(review): this guard sits in the middle of the module, so when the file is
    # run directly the definitions below it have not executed yet. It presumably works
    # because uvicorn receives the import string "api_server:app" and imports the
    # module afresh (which assumes this file is api_server.py). Moving the guard to
    # the end of the file would be clearer.
    import uvicorn

    uvicorn_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("api_server:app", **uvicorn_options)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import create_engine |
|
|
from sqlalchemy.orm import sessionmaker, Session |
|
|
import sys |
|
|
import os |
|
|
# Make the directory containing this file importable so the sibling ``models`` module
# imported below resolves regardless of the current working directory.
#   * ``abspath`` guards against a relative ``__file__``: its dirname would be "",
#     which silently means "whatever the cwd is at import time".
#   * The membership check avoids piling up duplicate entries when the module is
#     imported more than once (e.g. by the auto-reloader).
_module_dir = os.path.dirname(os.path.abspath(__file__))
if _module_dir not in sys.path:
    sys.path.append(_module_dir)
|
|
|
|
|
from models import Base, Patient, Nurse, Record |
|
|
|
|
|
# SQLite database file, resolved relative to the process's current working directory.
DATABASE_URL = "sqlite:///./nursing.db"

# ``check_same_thread=False`` lets a SQLite connection be used from threads other than
# the one that opened it. This is presumably needed because FastAPI may run the sync
# ``def`` endpoints on worker threads; each request still gets its own session via get_db.
engine = create_engine(DATABASE_URL, connect_args=dict(check_same_thread=False))

# Session factory used by get_db(): no autocommit and no automatic flushes.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)

# Create any tables declared on ``Base`` that do not exist yet (runs at import time).
Base.metadata.create_all(bind=engine)
|
|
|
|
|
|
|
|
def get_db():
    """Yield one database session per request and always close it afterwards.

    Used as a FastAPI dependency (``Depends(get_db)``) by the route handlers.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Runs once the request has been handled, even if the handler raised.
        session.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
class LoginRequest(BaseModel):
    """Request body for ``POST /login``."""

    # A nurse's ``staff_id``, or the admin username checked in login().
    username: str

    # Plain-text password, compared against the stored value in login().
    password: str
|
|
|
|
|
|
|
|
class PatientCreate(BaseModel):
    """Request body for creating (``POST /patients``) and fully replacing
    (``PUT /patients/{patient_id}``) a patient.

    Every field is passed straight to the ``Patient`` ORM model, so the field names
    are expected to match its attributes.
    """

    # Identity and contact details; only ``name`` is required.
    name: str
    mrn: str | None = None  # presumably the medical record number -- confirm
    birth: str | None = None  # free-form string; no date format is enforced here
    gender: str | None = None
    phone: str | None = None
    email: str | None = None

    # Contact person (the "emg_" prefix presumably means emergency -- confirm).
    emg_name: str | None = None
    emg_phone: str | None = None
    emg_relation: str | None = None

    # Admission details; all required except ``doctor``.
    room: str
    department: str
    doctor: str | None = None
    diagnosis: str
    risk: str
    admit_date: str  # free-form string; no date format is enforced here
|
|
|
|
|
|
|
|
|
|
|
class NurseCreate(BaseModel):
    """Request body for creating (``POST /nurses``) and updating
    (``PUT /nurses/{nurse_id}``) a nurse."""

    name: str

    # Used as the login username, and by create_nurse / reset_password as the password.
    staff_id: str

    department: str | None = None
    position: str | None = None
    phone: str | None = None
    email: str | None = None

    # Optional password. create_nurse ignores it and always uses ``staff_id`` instead.
    password: str | None = None
|
|
|
|
|
|
|
|
|
|
|
class RecordCreate(BaseModel):
    """Request body for creating (``POST /records``) and fully replacing
    (``PUT /records/{record_id}``) a nursing record."""

    # Owning patient and author nurse ids; their existence is not checked by the endpoints.
    patient_id: int
    nurse_id: int

    # Record text.
    content: str

    # Timestamp supplied by the client as a string; no format is enforced here.
    created_at: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/api/predict")
def predict_completion(request: PredictionRequest, db: Session = Depends(get_db)):
    """Run the agent on a prompt for one patient and return its completion.

    Args:
        request: Body with a non-blank ``prompt`` and a ``patient_id`` that parses as
            an integer.
        db: Request-scoped SQLAlchemy session, passed through to ``run_agent``.

    Returns:
        ``{"completions": [result]}``, where ``result`` is whatever ``run_agent`` returns.

    Raises:
        HTTPException(400): blank prompt, missing ``patient_id``, or a ``patient_id``
            that is not an integer.
    """
    # NOTE(review): this prints the whole request, including the clinical prompt, to
    # stdout on every call. Consider structured logging with redaction.
    print("收到的 request:", request.dict())

    # ``prompt`` is a required str, so stripping covers both "" and whitespace-only input.
    if not request.prompt.strip():
        raise HTTPException(status_code=400, detail="prompt 不可為空")

    if request.patient_id is None:
        raise HTTPException(status_code=400, detail="需要 patient_id")

    try:
        patient_id = int(request.patient_id)
    except (TypeError, ValueError):
        # Only a failed conversion is a client error; a bare ``except`` also swallowed
        # unrelated exceptions (even KeyboardInterrupt) and reported them as a 400.
        raise HTTPException(status_code=400, detail="patient_id 必須是整數") from None

    result = run_agent(user_input=request.prompt, patient_id=patient_id, db=db)
    return {"completions": [result]}
|
|
|
|
|
@app.post("/patients")
def create_patient(data: PatientCreate, db: Session = Depends(get_db)):
    """Insert a new patient built from the request body and return the stored row."""
    new_patient = Patient(**data.dict())
    db.add(new_patient)
    db.commit()
    # Reload from the database so generated values (e.g. the assigned id) are populated.
    db.refresh(new_patient)
    return new_patient
|
|
|
|
|
|
|
|
@app.get("/patients")
def list_patients(db: Session = Depends(get_db)):
    """Return every patient (no filtering or pagination)."""
    all_patients = db.query(Patient).all()
    return all_patients
|
|
|
|
|
|
|
|
@app.get("/patients/{patient_id}")
def get_patient(patient_id: int, db: Session = Depends(get_db)):
    """Return one patient by ``id``, or respond 404 if none exists."""
    found_patient = db.query(Patient).filter(Patient.id == patient_id).first()
    if found_patient:
        return found_patient
    raise HTTPException(status_code=404, detail="找不到病患")
|
|
|
|
|
@app.put("/patients/{patient_id}")
def update_patient(patient_id: int, data: PatientCreate, db: Session = Depends(get_db)):
    """Overwrite every field of an existing patient with the request body.

    This is a full replacement: optional fields omitted from the body are reset to
    their ``PatientCreate`` defaults (``None``).

    Raises:
        HTTPException(404): no patient with this id.
    """
    target = db.query(Patient).filter(Patient.id == patient_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="找不到病患")

    new_values = data.dict()
    for field_name in new_values:
        setattr(target, field_name, new_values[field_name])

    db.commit()
    db.refresh(target)
    return target
|
|
|
|
|
@app.delete("/patients/{patient_id}")
def delete_patient(patient_id: int, db: Session = Depends(get_db)):
    """Delete a patient by id, or respond 404 if none exists.

    NOTE(review): records referencing this patient are not touched here. Confirm that
    the ``Record`` model cascades, or that orphaned records are acceptable.
    """
    patient_row = db.query(Patient).filter(Patient.id == patient_id).first()
    if not patient_row:
        raise HTTPException(status_code=404, detail="找不到病患")
    db.delete(patient_row)
    db.commit()
    return {"message": "病患已刪除"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/nurses")
def create_nurse(data: NurseCreate, db: Session = Depends(get_db)):
    """Create a nurse account and return the stored row.

    Any ``password`` in the body is ignored: the initial password is always the
    nurse's ``staff_id``, the same value ``reset_password`` restores.

    NOTE(review): passwords are stored as given (plain text), and the returned ORM
    object is likely serialized including its ``password`` attribute.
    """
    fields = {**data.dict(), "password": data.staff_id}
    new_nurse = Nurse(**fields)
    db.add(new_nurse)
    db.commit()
    db.refresh(new_nurse)
    return new_nurse
|
|
|
|
|
@app.delete("/nurses/{nurse_id}")
def delete_nurse(nurse_id: int, db: Session = Depends(get_db)):
    """Delete a nurse by id, or respond 404 if none exists."""
    nurse_row = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if nurse_row:
        db.delete(nurse_row)
        db.commit()
        return {"message": "護理師已刪除"}
    raise HTTPException(status_code=404, detail="找不到護理師")
|
|
|
|
|
@app.get("/nurses")
def list_nurses(db: Session = Depends(get_db)):
    """Return every nurse (no filtering or pagination)."""
    all_nurses = db.query(Nurse).all()
    return all_nurses
|
|
|
|
|
@app.get("/nurses/{nurse_id}")
def get_nurse(nurse_id: int, db: Session = Depends(get_db)):
    """Return one nurse by ``id``, or respond 404 if none exists."""
    found_nurse = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if found_nurse:
        return found_nurse
    raise HTTPException(status_code=404, detail="找不到護理師")
|
|
|
|
|
# A second, byte-for-byte identical ``DELETE /nurses/{nurse_id}`` handler used to sit
# here. It was unreachable, because the router dispatches to the first matching route,
# which is ``delete_nurse`` defined earlier in this file. It also rebound that
# module-level name and caused FastAPI to warn about a duplicate operation ID when
# building the OpenAPI schema, so it has been removed.
|
|
|
|
|
@app.put("/nurses/{nurse_id}")
def update_nurse(nurse_id: int, data: NurseCreate, db: Session = Depends(get_db)):
    """Overwrite an existing nurse's fields with the request body.

    All fields are replaced (omitted optional fields become ``None``) except
    ``password``, which changes only when the body supplies one. Storing ``None``
    there would make ``login`` reject every password for this nurse until an admin
    called ``reset_password``.

    Raises:
        HTTPException(404): no nurse with this id.
    """
    nurse = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not nurse:
        raise HTTPException(status_code=404, detail="找不到護理師")

    for key, value in data.dict().items():
        # Keep the existing password unless a new one was explicitly provided.
        if key == "password" and value is None:
            continue
        setattr(nurse, key, value)

    db.commit()
    db.refresh(nurse)
    return nurse
|
|
|
|
|
@app.put("/nurses/{nurse_id}/reset-password")
def reset_password(nurse_id: int, db: Session = Depends(get_db)):
    """Reset a nurse's password to their ``staff_id``, or respond 404 if no such nurse exists."""
    nurse_row = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not nurse_row:
        raise HTTPException(status_code=404, detail="找不到護理師")
    # The same default that create_nurse assigns to new accounts.
    nurse_row.password = nurse_row.staff_id
    db.commit()
    return {"message": "密碼已重設"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/records")
def create_record(data: RecordCreate, db: Session = Depends(get_db)):
    """Insert a nursing record from the request body and return the stored row.

    ``patient_id`` and ``nurse_id`` are stored as given; their existence is not checked here.
    """
    new_record = Record(**data.dict())
    db.add(new_record)
    db.commit()
    db.refresh(new_record)
    return new_record
|
|
|
|
|
|
|
|
@app.get("/records/{patient_id}")
def list_records(patient_id: int, db: Session = Depends(get_db)):
    """Return all records for one patient (an empty list when there are none; never 404)."""
    matching = db.query(Record).filter(Record.patient_id == patient_id)
    return matching.all()
|
|
|
|
|
@app.get("/records/detail/{record_id}")
def get_record_detail(record_id: int, db: Session = Depends(get_db)):
    """Return a single record by ``id``, or respond 404 if none exists."""
    record_row = db.query(Record).filter(Record.id == record_id).first()
    if record_row:
        return record_row
    raise HTTPException(status_code=404, detail="找不到紀錄")
|
|
|
|
|
@app.put("/records/{record_id}")
def update_record(record_id: int, data: RecordCreate, db: Session = Depends(get_db)):
    """Replace every field of an existing record with the request body.

    Raises:
        HTTPException(404): no record with this id.
    """
    record_row = db.query(Record).filter(Record.id == record_id).first()
    if not record_row:
        raise HTTPException(status_code=404, detail="找不到紀錄")

    changes = data.dict()
    for attr in changes:
        setattr(record_row, attr, changes[attr])

    db.commit()
    db.refresh(record_row)
    return record_row
|
|
|
|
|
@app.delete("/records/{record_id}")
def delete_record(record_id: int, db: Session = Depends(get_db)):
    """Delete a record by id, or respond 404 if none exists."""
    record_row = db.query(Record).filter(Record.id == record_id).first()
    if record_row:
        db.delete(record_row)
        db.commit()
        return {"message": "紀錄已刪除"}
    raise HTTPException(status_code=404, detail="找不到紀錄")
|
|
|
|
|
@app.post("/login")
def login(data: LoginRequest, db: Session = Depends(get_db)):
    """Authenticate the built-in admin or a nurse and return basic identity info.

    Nurses log in with their ``staff_id`` as the username. The admin credentials
    default to ``admin`` / ``1234``. They can be overridden with the
    ``NURSING_ADMIN_USERNAME`` / ``NURSING_ADMIN_PASSWORD`` environment variables,
    so a real password need not live in source control.

    Returns:
        ``{"role": ..., "name": ..., "id": ...}``; the admin gets id 0.

    Raises:
        HTTPException(401): unknown username or wrong password.
    """
    import hmac
    import os

    def _password_matches(expected, supplied: str) -> bool:
        # Constant-time comparison, so response timing does not reveal how much of a
        # password matched. A non-string stored value (e.g. NULL) never matches,
        # exactly as the plain ``!=`` comparison behaved.
        if not isinstance(expected, str):
            return False
        return hmac.compare_digest(
            expected.encode("utf-8", "surrogatepass"),
            supplied.encode("utf-8", "surrogatepass"),
        )

    admin_username = os.environ.get("NURSING_ADMIN_USERNAME", "admin")
    admin_password = os.environ.get("NURSING_ADMIN_PASSWORD", "1234")
    if data.username == admin_username and _password_matches(admin_password, data.password):
        return {"role": "admin", "name": "Admin", "id": 0}

    nurse = db.query(Nurse).filter(Nurse.staff_id == data.username).first()
    if not nurse:
        # NOTE(review): distinct messages for "no such account" vs "wrong password"
        # let callers enumerate valid staff ids.
        raise HTTPException(status_code=401, detail="帳號不存在")

    if not _password_matches(nurse.password, data.password):
        raise HTTPException(status_code=401, detail="密碼錯誤")

    return {"role": "nurse", "name": nurse.name, "id": nurse.id}
|
|
|
|
|
@app.get("/current-user")
def current_user(token: str | None = None, db: Session = Depends(get_db)):
    """Resolve the ``token`` query parameter to the logged-in user.

    The token is either the literal ``"admin"`` or a nurse's numeric ``id`` (the value
    ``login`` returns as ``id``).

    NOTE(review): these tokens are not secret. Anyone can pass ``token=admin`` or guess
    a nurse id; replace them with signed/session tokens before real use.

    Raises:
        HTTPException(401): no token, a token that is not an integer id, or an id with
            no matching nurse.
    """
    if token == "admin":
        return {"role": "admin", "name": "Admin", "id": 0}

    if token is None:
        raise HTTPException(status_code=401, detail="未登入")

    # int() raises ValueError for a non-numeric token. Unhandled, that surfaced as an
    # HTTP 500; such a token cannot belong to any nurse, so report it like an unknown id.
    try:
        nurse_id = int(token)
    except ValueError:
        raise HTTPException(status_code=401, detail="登入者不存在") from None

    nurse = db.query(Nurse).filter(Nurse.id == nurse_id).first()
    if not nurse:
        raise HTTPException(status_code=401, detail="登入者不存在")

    return {"role": "nurse", "name": nurse.name, "id": nurse.id}
|
|
|
|
|
from fastapi.responses import FileResponse |
|
|
@app.get("/download-db")
def download_db():
    """Send the SQLite file ``nursing.db`` (relative to the working directory) as a download.

    NOTE(review): this endpoint is unauthenticated and exposes every patient, every
    nurse (including stored plain-text passwords) and every record. Restrict or remove
    it outside development.
    """
    db_file = "nursing.db"
    return FileResponse(db_file, filename=db_file)