File size: 2,897 Bytes
e458941
 
 
b4525ce
 
e458941
 
a6eee8d
 
e458941
 
 
a6eee8d
 
 
e458941
 
a6eee8d
 
 
 
 
 
 
e458941
 
b4525ce
 
 
 
 
 
 
 
e458941
 
a6eee8d
 
 
 
e458941
a6eee8d
 
 
 
 
 
 
 
 
 
 
b4525ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e458941
 
a6eee8d
e458941
a6eee8d
e458941
a6eee8d
 
 
 
 
 
 
 
e458941
 
 
 
 
a6eee8d
e458941
 
 
 
 
 
 
a6eee8d
 
 
 
 
 
 
 
 
e458941
a6eee8d
e458941
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import uvicorn
from llm.llm import VirtualNurseLLM
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
import os
from dotenv import load_dotenv
# Populate environment variables from a local .env file (python-dotenv) so
# the TYPHOON_CHAT_KEY lookup below can see it.
load_dotenv()

# Shared LLM session used by every endpoint in this module: the
# typhoon-v1.5x-70b-instruct model served from the OpenTyphoon API.
# os.getenv yields None when TYPHOON_CHAT_KEY is not set.
_typhoon_api_key = os.getenv("TYPHOON_CHAT_KEY")
nurse_llm = VirtualNurseLLM(
    base_url="https://api.opentyphoon.ai/v1",
    model="typhoon-v1.5x-70b-instruct",
    api_key=_typhoon_api_key,
)

# model: OpenThaiGPT
# nurse_llm = VirtualNurseLLM(
#     base_url="https://api.aieat.or.th/v1",
#     model=".",
#     api_key="dummy"
# )

app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): a wildcard origin list combined with allow_credentials=True
# is very permissive; if only a known front-end calls this API, consider
# listing its origin explicitly. Confirm the intended policy.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

class UserInput(BaseModel):
    """Request body for ``POST /nurse_response``."""

    # Message text that the endpoint passes to ``nurse_llm.invoke``.
    user_input: str
    # Backend selector; the /nurse_response handler accepts
    # "typhoon-v1.5x-70b-instruct" (this default) or "openthaigpt".
    model_name: str = "typhoon-v1.5x-70b-instruct"

class NurseResponse(BaseModel):
    """Response body for ``POST /nurse_response``."""

    # The value returned by ``nurse_llm.invoke``, validated as a string.
    nurse_response: str

class EHRData(BaseModel):
    """Response body for ``GET /details``.

    Each field is filled from the same-named attribute of the shared
    ``nurse_llm`` session object.
    """

    # EHR data held by the session.
    # NOTE(review): the key schema is defined by VirtualNurseLLM and is not
    # visible here; confirm before relying on specific keys.
    ehr_data: dict
    current_context: str
    current_prompt: str
    current_prompt_ehr: str
    current_patient_response: str
    current_question: str

class ChatHistory(BaseModel):
    """Response body for ``GET /history``."""

    # Conversation log copied from ``nurse_llm.chat_history``.
    # NOTE(review): element format is defined by VirtualNurseLLM (presumably
    # role/content records) -- TODO confirm.
    chat_history: list
    
@app.get("/", response_class=HTMLResponse)
def read_index():
    """Serve a minimal HTML landing page that links to the Swagger UI at ``/docs``.

    Returns:
        The page markup as a string; ``HTMLResponse`` sends it as text/html.
    """
    # The <title> element must be closed with a real "</title>" tag. Without
    # it, the HTML parser reads everything after <title> as title text, and
    # the page body never renders.
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>MALI_NURSE API</title>
    </head>
    <body>
        <h1>Welcome to MALI_NURSE API</h1>
        <p>This is the index page. Use the link below to access the API docs:</p>
        <a href="/docs">Go to Swagger Docs UI</a>
    </body>
    </html>
    """

@app.get("/history")
def get_chat_history():
    """Expose the shared session's conversation log wrapped in ``ChatHistory``."""
    current_history = nurse_llm.chat_history
    return ChatHistory(chat_history=current_history)

@app.get("/details")
def get_ehr_data():
    """Snapshot the session's EHR data and current prompt/context state.

    Every ``EHRData`` field is copied from the same-named attribute of the
    shared ``nurse_llm`` object, read in declaration order.
    """
    attribute_names = (
        "ehr_data",
        "current_context",
        "current_prompt",
        "current_prompt_ehr",
        "current_patient_response",
        "current_question",
    )
    snapshot = {name: getattr(nurse_llm, name) for name in attribute_names}
    return EHRData(**snapshot)

def toggle_debug():
    """Flip ``nurse_llm.debug`` and report the resulting state.

    Returns:
        ``{"debug_mode": "on"}`` or ``{"debug_mode": "off"}``, based on the
        value of ``nurse_llm.debug`` after the flip.

    NOTE(review): unlike the surrounding handlers, this function has no
    ``@app`` route decorator, so FastAPI never registers it as an endpoint.
    Confirm whether that is intentional (keeping debug toggling private) or
    a missing ``@app.post(...)``.
    """
    nurse_llm.debug = not nurse_llm.debug
    # Indexing with the truth value selects "off" (False) or "on" (True).
    label = ("off", "on")[bool(nurse_llm.debug)]
    return {"debug_mode": label}


@app.post("/reset")
def data_reset():
    """Reset the shared ``nurse_llm`` session via its ``reset()`` method.

    According to the confirmation message printed here, this clears the chat
    history and EHR data. The endpoint sends no payload (JSON ``null``).
    """
    nurse_llm.reset()
    confirmation = "Chat history and EHR data have been reset."
    print(confirmation)
    return None

@app.post("/nurse_response")
def nurse_response(user_input: UserInput):
    """Return the virtual nurse's reply to a single user message.

    Supported ``model_name`` values are ``"typhoon-v1.5x-70b-instruct"``
    (the ``UserInput`` default) and ``"openthaigpt"``. The selected name is
    assigned to the shared ``nurse_llm.model`` before ``nurse_llm.invoke``
    runs.

    Returns:
        ``NurseResponse`` wrapping the reply on success. For any other
        ``model_name``, the JSON body ``{"error": "Invalid model name"}``
        with HTTP status 400.
    """
    # Imported locally so this handler does not depend on edits to the
    # module-level import block.
    from fastapi.responses import JSONResponse

    supported_models = ("typhoon-v1.5x-70b-instruct", "openthaigpt")
    if user_input.model_name not in supported_models:
        # Keep the original error body but send a 4xx status. Returning the
        # bare dict produced HTTP 200, which hid the failure from clients
        # that check the status code.
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid model name"},
        )

    # NOTE(review): nurse_llm is one module-level object shared by all
    # callers, so overlapping requests that ask for different models can
    # race on this assignment. Nothing here changes the base_url/api_key
    # given to VirtualNurseLLM at startup, so confirm the configured
    # endpoint actually serves "openthaigpt".
    nurse_llm.model = user_input.model_name
    response = nurse_llm.invoke(user_input.user_input)
    return NurseResponse(nurse_response=response)

if __name__ == "__main__":
    # Development entry point: listen on all interfaces at port 8000 with
    # auto-reload. The app is passed as the "main:app" import string, which
    # uvicorn requires for reload mode and which assumes this file is
    # importable as module "main".
    server_options = {
        "host": "0.0.0.0",
        "port": 8000,
        "reload": True,
    }
    uvicorn.run("main:app", **server_options)