File size: 7,672 Bytes
46461bf
5718ced
 
46461bf
 
 
5718ced
46461bf
5718ced
 
 
46461bf
5718ced
46461bf
5718ced
 
 
46461bf
 
 
 
 
753a52f
5718ced
 
 
 
 
 
753a52f
5718ced
46461bf
 
 
 
5718ced
 
 
 
 
 
 
 
753a52f
5718ced
753a52f
5718ced
753a52f
5718ced
753a52f
5718ced
46461bf
753a52f
46461bf
5718ced
46461bf
 
 
5718ced
46461bf
 
 
571dfaf
5718ced
46461bf
5718ced
46461bf
 
 
 
571dfaf
5718ced
46461bf
5718ced
46461bf
 
 
5718ced
571dfaf
5718ced
46461bf
5718ced
46461bf
 
 
5718ced
571dfaf
5718ced
46461bf
5718ced
46461bf
 
 
 
5718ced
 
46461bf
 
 
5718ced
 
 
 
46461bf
 
5718ced
 
 
46461bf
5718ced
46461bf
5718ced
 
 
 
46461bf
5718ced
46461bf
 
5718ced
 
 
 
 
46461bf
 
5718ced
46461bf
 
5718ced
 
46461bf
 
 
5718ced
 
 
46461bf
 
 
 
 
5718ced
46461bf
5718ced
46461bf
 
 
 
 
 
5718ced
46461bf
5718ced
46461bf
5718ced
 
 
 
 
46461bf
 
 
 
 
 
5718ced
 
 
 
 
 
 
 
46461bf
5718ced
 
 
46461bf
 
 
5718ced
46461bf
5718ced
2714b57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import gradio as gr
import re
from transformers import BioGptTokenizer, BioGptForCausalLM, pipeline
from crewai import Agent, Task, Crew, Process
from crewai.llm import BaseLLM

# ── 1. Load BioGPT once ─────────────────────────────────────────────
# The tokenizer, model and text-generation pipeline are created once, when
# the module is imported, and then shared by every generation call.
MODEL_NAME = "microsoft/biogpt"
tokenizer = BioGptTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so inference mode is chained
# directly onto the load.
model = BioGptForCausalLM.from_pretrained(MODEL_NAME).eval()

# Default sampling configuration for every generation; individual calls may
# still override single values such as max_new_tokens.
pipe = pipeline(
    task="text-generation",
    tokenizer=tokenizer,
    model=model,
    do_sample=True,
    temperature=0.65,
    top_p=0.92,
    repetition_penalty=1.35,
    max_new_tokens=250,
)

def biogpt_complete(prompt: str, max_new=200) -> str:
    """Generate a BioGPT continuation of ``prompt`` trimmed to whole sentences.

    Args:
        prompt: Text for BioGPT to continue.
        max_new: Maximum number of new tokens to generate; overrides the
            pipeline-level ``max_new_tokens`` for this call.

    Returns:
        The generated continuation (prompt excluded), cut just after its last
        period so it does not end mid-sentence. Returns the placeholder
        "Insufficient data generated." when nothing usable was produced.
    """
    # return_full_text=False makes the pipeline return only the new text.
    # Slicing the full output by len(prompt) assumes the decoded text repeats
    # the prompt character-for-character, which the tokenizer's
    # decode/detokenization step does not guarantee.
    outputs = pipe(prompt, max_new_tokens=max_new, return_full_text=False)
    generated = outputs[0]["generated_text"].strip() if outputs else ""

    last_period = generated.rfind(".")
    if last_period != -1:
        # Drop the trailing fragment left when generation hit the token limit.
        generated = generated[: last_period + 1]
    return generated or "Insufficient data generated."

# ── 2. Wrap BioGPT in BaseLLM for CrewAI ─────────────────────────────
# Prompt template per role; "{case}" is replaced by the extracted patient case.
_BIOGPT_ROLE_PROMPTS = {
    "diagnosis": "A {case} The differential diagnosis includes:",
    "treatment": "For a patient with {case} The recommended treatment includes:",
    "precaution": "For a patient with {case} Key precautions and follow-up include:",
}
# Used for every other role, including the coordinator ("summary").
_BIOGPT_DEFAULT_PROMPT = "The clinical findings for {case} suggest:"

# Agent role names (as configured in make_agents) mapped to prompt roles.
# CrewAI does not forward per-task kwargs to LLM.call, so the calling agent
# is recognised from the message text instead. This presumes CrewAI's agent
# prompt includes the agent's role string (verify if its template changes).
_BIOGPT_AGENT_ROLES = {
    "Medical Diagnostician": "diagnosis",
    "Treatment Planner": "treatment",
    "Precaution Advisor": "precaution",
    "Medical Coordinator": "summary",
}

# "<age>-year-old ... ." sentence that ends its line. This allows periods
# inside the symptoms, e.g. "fever 38.5 C".
_CASE_LINE_RE = re.compile(r"(\d+-year-old .*?\.)[ \t]*$", re.M)
# Shortest "<age>-year-old ... ." sentence; used when no line-ending match
# starts earlier in the text.
_CASE_SENTENCE_RE = re.compile(r"(\d+-year-old .+?\.)")


def _extract_case(text):
    """Return the earliest patient-case sentence in ``text``, else its first 200 chars."""
    candidates = [
        match
        for match in (_CASE_LINE_RE.search(text), _CASE_SENTENCE_RE.search(text))
        if match
    ]
    if not candidates:
        return text[:200]
    # min() keeps the first candidate on ties, so the line-ending match wins
    # when both patterns start at the same position.
    return min(candidates, key=lambda match: match.start()).group(1).strip()


def _infer_agent_role(text):
    """Map the earliest known agent role name found in ``text`` to a prompt role."""
    hits = [
        (text.find(name), role)
        for name, role in _BIOGPT_AGENT_ROLES.items()
        if name in text
    ]
    return min(hits)[1] if hits else "general"


class BioGPTLLM(BaseLLM):
    """CrewAI ``BaseLLM`` adapter that answers every agent call with local BioGPT.

    BioGPT is a plain completion model, so the chat transcript is not
    forwarded verbatim. Instead, the patient case is pulled out of the
    messages and wrapped in a short role-specific completion prompt.
    """

    def __init__(self):
        super().__init__(model="biogpt-local", temperature=0.65)

    def call(self, messages, **kwargs):
        """Answer one CrewAI request with a BioGPT completion.

        Args:
            messages: A raw prompt string (used as the case as-is) or a list
                of chat-message dicts whose ``"content"`` values are searched
                for the patient case.
            **kwargs: Extra arguments passed by CrewAI. An explicit ``role``
                key overrides role inference; all others are ignored.

        Returns:
            The BioGPT text prefixed with ``"Final Answer: "`` so that CrewAI's
            ReAct-style output parser treats it as the agent's final answer.
        """
        if isinstance(messages, str):
            transcript = messages
            case = messages
        else:
            # Content may be None (or absent) for some messages; treat as empty.
            transcript = "\n".join(str(m.get("content") or "") for m in messages)
            case = _extract_case(transcript)

        role = kwargs.get("role") or _infer_agent_role(transcript)
        template = _BIOGPT_ROLE_PROMPTS.get(role, _BIOGPT_DEFAULT_PROMPT)
        # str.format does not re-parse the substituted value, so braces in
        # the user's symptoms are safe.
        prompt = template.format(case=case)
        return f"Final Answer: {biogpt_complete(prompt)}"


biogpt_llm = BioGPTLLM()

# ── 3. Agents ───────────────────────────────────────────────────────
def make_agents():
    """Build the four CrewAI agents, all backed by the shared BioGPT LLM.

    Returns:
        A ``(diagnostician, treatment_planner, precaution_advisor,
        coordinator)`` tuple, in that order.
    """
    def build(role, backstory, goal):
        # Every agent uses the same LLM, stays quiet and never delegates.
        return Agent(
            role=role,
            backstory=backstory,
            goal=goal,
            llm=biogpt_llm,
            verbose=False,
            allow_delegation=False,
        )

    # (role, backstory, goal) for each agent, in the order they are returned.
    specs = (
        (
            "Medical Diagnostician",
            "You are a medical expert specializing in diagnosing patients based on symptoms.",
            "Produce a differential diagnosis list for the patient.",
        ),
        (
            "Treatment Planner",
            "You are a medical expert specializing in suggesting treatments for patients.",
            "Recommend evidence-based treatment for the patient.",
        ),
        (
            "Precaution Advisor",
            "You are a medical expert who advises patients on safety, lifestyle, and follow-up.",
            "Provide safety, lifestyle, and follow-up guidance.",
        ),
        (
            "Medical Coordinator",
            "You combine outputs from other medical agents into a concise summary.",
            "Combine outputs into a concise three-section summary.",
        ),
    )
    return tuple(build(*spec) for spec in specs)

# ── 4. CrewAI Tasks ────────────────────────────────────────────────
def make_tasks(symptoms, age, gender, diag, treat, prec, coord):
    """Create the four sequential CrewAI tasks for one patient case.

    All tasks share a single case description of the form
    ``"<age>-year-old <gender> presenting with: <symptoms>."``. The treatment
    and precaution tasks receive the diagnosis task as context; the summary
    task receives all three.

    Args:
        symptoms: Free-text symptoms (may span several lines).
        age: Patient age as entered in the UI.
        gender: Patient gender as entered in the UI.
        diag, treat, prec, coord: Agents from ``make_agents``, in that order.

    Returns:
        ``(task_diagnose, task_treat, task_prec, task_summarise)``.
    """
    # Collapse newlines and whitespace runs so the case stays on one line.
    # The LLM wrapper finds the case with a regex whose "." (no re.S) cannot
    # cross a newline. Trailing periods/spaces are dropped so the sentence
    # ends with exactly one period.
    symptom_text = " ".join(str(symptoms).split()).rstrip(" .")
    case = (
        f"{str(age).strip()}-year-old {str(gender).strip()} "
        f"presenting with: {symptom_text}."
    )

    # Role-specific prompting is not configured here: CrewAI does not forward
    # a Task-level ``kwargs`` value to LLM.call.
    task_diagnose = Task(
        description=case,
        expected_output="Differential diagnoses list",
        agent=diag,
    )
    task_treat = Task(
        description=case,
        expected_output="Treatment recommendations",
        agent=treat,
        context=[task_diagnose],
    )
    task_prec = Task(
        description=case,
        expected_output="Precautions",
        agent=prec,
        context=[task_diagnose],
    )
    task_summarise = Task(
        description=case,
        expected_output="Combined summary",
        agent=coord,
        context=[task_diagnose, task_treat, task_prec],
    )

    return task_diagnose, task_treat, task_prec, task_summarise

# ── 5. Parse coordinator output ────────────────────────────────────
def parse_summary(text):
    """Split a markdown summary into diagnosis, treatment and precaution sections.

    Recognised headers are ``Diagnosis``/``Diagnoses``, ``Treatment(s)`` and
    ``Precaution(s)``. They may use any heading level (``#``, ``##``, ``###``
    ...), any letter case and an optional trailing colon, and may appear in
    any order. Each section runs until the next recognised header or the end
    of the text.

    Args:
        text: Raw summary text.

    Returns:
        Dict with keys ``"diagnosis"``, ``"treatment"`` and ``"precautions"``
        (empty string for a missing section). If no section has content, the
        whole stripped text is placed under ``"diagnosis"``.
    """
    header_names = {
        "diagnosis": r"Diagnos(?:is|es)",
        "treatment": r"Treatments?",
        "precautions": r"Precautions?",
    }
    # A section ends where any recognised header begins.
    next_header = r"#+\s*(?:" + "|".join(header_names.values()) + r")"

    sections = {key: "" for key in header_names}
    for key, name in header_names.items():
        # The optional plural "s" and ":" belong to the header, not the content.
        match = re.search(
            rf"#+\s*{name}[ \t]*:?(.*?)(?={next_header}|\Z)",
            text,
            re.S | re.I,
        )
        if match:
            sections[key] = match.group(1).strip()

    if not any(sections.values()):
        sections["diagnosis"] = text.strip()
    return sections

# ── 6. Main orchestration ──────────────────────────────────────────
def analyze(symptoms, age, gender):
    """Run the BioGPT agent crew on one case and format the three result panes.

    Args:
        symptoms: Free-text symptoms from the UI.
        age: Age as typed into the UI; must be a whole number.
        gender: Gender selected in the UI.

    Returns:
        A ``(diagnosis_md, treatment_md, precautions_md)`` tuple of markdown
        strings, each ending with the disclaimer. On invalid input the first
        element is a warning and the other two are empty.
    """
    if not (symptoms or "").strip():
        return "⚠️ Enter symptoms.", "", ""
    age_text = "" if age is None else str(age).strip()
    # The LLM wrapper locates the case via a "<digits>-year-old" pattern, so a
    # missing or non-numeric age would silently derail prompt construction.
    if not age_text.isdecimal():
        return "⚠️ Enter age as a whole number.", "", ""

    diag, treat, prec, coord = make_agents()
    tasks = make_tasks(symptoms, age_text, gender, diag, treat, prec, coord)
    crew = Crew(
        agents=[diag, treat, prec, coord],
        tasks=list(tasks),
        process=Process.sequential,
        verbose=False
    )
    result = crew.kickoff()

    # kickoff() yields the final (coordinator) output; split it if it happens
    # to contain "## Diagnosis"-style headers.
    sections = parse_summary(str(result))
    # The coordinator's BioGPT prompt never asks for those headers, so prefer
    # each specialist's own output when CrewAI exposes per-task results. The
    # first three tasks are diagnosis, treatment and precautions, in order.
    task_outputs = getattr(result, "tasks_output", None) or []
    for key, output in zip(("diagnosis", "treatment", "precautions"), task_outputs):
        text = str(getattr(output, "raw", output)).strip()
        if text:
            sections[key] = text

    disclaimer = "\n\n---\n> ⚠️ **Disclaimer:** AI-generated output, not medical advice."
    return (
        f"### Diagnosis\n{sections['diagnosis']}{disclaimer}",
        f"### Treatment\n{sections['treatment']}{disclaimer}",
        f"### Precautions\n{sections['precautions']}{disclaimer}"
    )

def clear_all():
    """Return blank values for the five components reset by the Clear button.

    Order: symptoms, age, diagnosis, treatment, precautions.
    """
    return ("",) * 5

# ── 7. Gradio UI ───────────────────────────────────────────────────
with gr.Blocks(title="MedAgent — CrewAI + BioGPT") as demo:
    with gr.Row():
        # Left column: case inputs and action buttons.
        with gr.Column(scale=1):
            symptoms_in = gr.Textbox(lines=6, label="Symptoms")
            age_in = gr.Textbox(label="Age")
            gender_in = gr.Dropdown(
                choices=["male", "female", "other"],
                value="male",
                label="Gender",
            )
            analyze_btn = gr.Button("Analyze")
            clear_btn = gr.Button("Clear")
        # Right column: one markdown pane per report section.
        with gr.Column(scale=2):
            diagnosis_out = gr.Markdown()
            treatment_out = gr.Markdown()
            precautions_out = gr.Markdown()

    analyze_btn.click(
        analyze,
        inputs=[symptoms_in, age_in, gender_in],
        outputs=[diagnosis_out, treatment_out, precautions_out],
    )
    # Clear resets the text inputs and result panes; the gender dropdown
    # is not among the outputs, so it keeps its current value.
    clear_btn.click(
        clear_all,
        inputs=[],
        outputs=[symptoms_in, age_in, diagnosis_out, treatment_out, precautions_out],
    )

if __name__ == "__main__":
    # Binds to all network interfaces on port 7860; narrow server_name if the
    # app should only be reachable locally.
    demo.launch(server_name="0.0.0.0", server_port=7860)