# Spaces: Sleeping  (status text captured with the page; not part of the program)
import re

import gradio as gr
from crewai import Agent, Task, Crew, Process
from crewai.llm import BaseLLM
from transformers import BioGptTokenizer, BioGptForCausalLM, pipeline
# ── 1. Load BioGPT once ─────────────────────────────────────────────
# Tokenizer, model and pipeline are built at module level, so every request
# handled by this process shares the same objects.
MODEL_NAME = "microsoft/biogpt"

tokenizer = BioGptTokenizer.from_pretrained(MODEL_NAME)
model = BioGptForCausalLM.from_pretrained(MODEL_NAME)
model.eval()

# Default generation settings handed to the text-generation pipeline;
# individual calls may override them (biogpt_complete overrides
# max_new_tokens).
_GENERATION_DEFAULTS = {
    "max_new_tokens": 250,
    "do_sample": True,
    "temperature": 0.65,
    "top_p": 0.92,
    "repetition_penalty": 1.35,
}

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    **_GENERATION_DEFAULTS,
)
def biogpt_complete(prompt: str, max_new: int = 200) -> str:
    """Generate a BioGPT continuation of *prompt*, trimmed to whole sentences.

    Args:
        prompt: Text the model should continue.
        max_new: Upper bound on newly generated tokens for this call.

    Returns:
        The generated continuation cut just after its last period (left
        untouched when it contains no period), or the placeholder
        ``"Insufficient data generated."`` when nothing usable came back.
    """
    # Ask the pipeline for the continuation only. Slicing the full output by
    # ``len(prompt)`` is fragile: the decoded text is not guaranteed to echo
    # the prompt character-for-character, which would shift the cut point.
    outputs = pipe(prompt, max_new_tokens=max_new, return_full_text=False)
    continuation = outputs[0]["generated_text"].strip()

    # Drop a trailing sentence fragment so the answer ends on a full stop.
    last_period = continuation.rfind(".")
    if last_period != -1:
        continuation = continuation[: last_period + 1]

    return continuation or "Insufficient data generated."
# ── 2. Wrap BioGPT in BaseLLM for CrewAI ─────────────────────────────
# Finds the patient sentence inside the prompt text, e.g.
# "45-year-old male presenting with: cough, fever 38.5 C."
# The closing period must be followed by whitespace or end-of-text so that a
# decimal such as "38.5" does not end the sentence early; re.S lets the
# sentence span several lines.
_CASE_SENTENCE_RE = re.compile(r"\d+-year-old .+?\.(?=\s|$)", re.S)

# Prompt stems per internal role; "{case}" is replaced by the patient sentence.
_PROMPT_TEMPLATES = {
    "diagnosis": "A {case} The differential diagnosis includes:",
    "treatment": "For a patient with {case} The recommended treatment includes:",
    "precaution": "For a patient with {case} Key precautions and follow-up include:",
}
_DEFAULT_PROMPT_TEMPLATE = "The clinical findings for {case} suggest:"

# Agent titles (lower-cased) mapped to the internal role they stand for.
_ROLE_BY_AGENT_TITLE = {
    "medical diagnostician": "diagnosis",
    "treatment planner": "treatment",
    "precaution advisor": "precaution",
}


def _message_text(messages):
    """Flatten a prompt string or a list of chat messages into one string."""
    if isinstance(messages, str):
        return messages
    parts = []
    for message in messages or []:
        content = message.get("content") if isinstance(message, dict) else message
        # Skip None/empty content instead of letting str.join() raise.
        if content:
            parts.append(str(content))
    return " ".join(parts)


def _resolve_role(call_kwargs, text):
    """Pick the prompt role: explicit kwarg, then agent title, then "general"."""
    explicit = call_kwargs.get("role")
    if explicit is not None:
        return explicit if isinstance(explicit, str) else "general"

    # NOTE(review): assumes the framework may pass the calling agent as
    # ``from_agent`` and/or include the agent's title in the prompt text;
    # confirm against the installed CrewAI version.
    agent = call_kwargs.get("from_agent")
    agent_title = str(getattr(agent, "role", "") or "")
    for haystack in (agent_title.lower(), text.lower()):
        for title, role in _ROLE_BY_AGENT_TITLE.items():
            if title in haystack:
                return role
    return "general"


class BioGPTLLM(BaseLLM):
    """CrewAI LLM adapter that answers every call with a BioGPT completion.

    The incoming prompt is reduced to the patient sentence
    ("<age>-year-old ... .") and re-framed with a short, role-specific stem
    before being passed to ``biogpt_complete``.
    """

    def __init__(self):
        super().__init__(model="biogpt-local", temperature=0.65)

    def call(self, messages, **kwargs):
        """Return a BioGPT completion for the clinical case found in *messages*.

        Args:
            messages: Either a prompt string or a list of chat-message dicts
                whose ``"content"`` values are concatenated.
            **kwargs: Extra call arguments. ``role`` selects the prompt stem
                ("diagnosis", "treatment", "precaution"; anything else uses
                the generic stem). When ``role`` is absent it is inferred
                from ``from_agent.role`` or from agent titles in the text.

        Returns:
            The text produced by ``biogpt_complete`` for the built prompt.
        """
        text = _message_text(messages)

        match = _CASE_SENTENCE_RE.search(text)
        if match:
            # Collapse newlines/runs of spaces so the prompt is one sentence.
            case = " ".join(match.group(0).split())
        else:
            # No recognisable patient sentence: fall back to the prompt head.
            case = text[:200]

        role = _resolve_role(kwargs, text)
        template = _PROMPT_TEMPLATES.get(role, _DEFAULT_PROMPT_TEMPLATE)
        return biogpt_complete(template.format(case=case))


biogpt_llm = BioGPTLLM()
# ── 3. Agents ───────────────────────────────────────────────────────
def make_agents():
    """Create the four CrewAI agents used for one analysis run.

    Returns:
        A ``(diagnostician, treatment_planner, precaution_advisor,
        coordinator)`` tuple. Every agent uses the shared ``biogpt_llm``,
        is non-verbose and has delegation disabled.
    """
    # (role, backstory, goal) for each agent, in return order.
    profiles = (
        (
            "Medical Diagnostician",
            "You are a medical expert specializing in diagnosing patients based on symptoms.",
            "Produce a differential diagnosis list for the patient.",
        ),
        (
            "Treatment Planner",
            "You are a medical expert specializing in suggesting treatments for patients.",
            "Recommend evidence-based treatment for the patient.",
        ),
        (
            "Precaution Advisor",
            "You are a medical expert who advises patients on safety, lifestyle, and follow-up.",
            "Provide safety, lifestyle, and follow-up guidance.",
        ),
        (
            "Medical Coordinator",
            "You combine outputs from other medical agents into a concise summary.",
            "Combine outputs into a concise three-section summary.",
        ),
    )
    return tuple(
        Agent(
            role=title,
            backstory=story,
            goal=objective,
            llm=biogpt_llm,
            verbose=False,
            allow_delegation=False,
        )
        for title, story, objective in profiles
    )
# ── 4. CrewAI Tasks ────────────────────────────────────────────────
def make_tasks(symptoms, age, gender, diag, treat, prec, coord):
    """Build the four tasks (diagnose, treat, precautions, summarise).

    Every task uses the same one-sentence patient description. The treatment
    and precaution tasks receive the diagnosis task as context; the summary
    task receives all three earlier tasks.

    Args:
        symptoms: Free-text symptom description.
        age: Patient age, interpolated as-is into the description.
        gender: Patient gender, interpolated as-is into the description.
        diag, treat, prec, coord: Agents assigned to the respective tasks.

    Returns:
        ``(task_diagnose, task_treat, task_prec, task_summarise)``.
    """
    patient_case = f"{age}-year-old {gender} presenting with: {symptoms}."

    def build(expected, agent, role, context=None):
        # Only pass ``context`` when one was requested, so the first task is
        # constructed without that argument at all.
        options = {
            "description": patient_case,
            "expected_output": expected,
            "agent": agent,
            "kwargs": {"role": role},
        }
        if context is not None:
            options["context"] = context
        return Task(**options)

    diagnose = build("Differential diagnoses list", diag, "diagnosis")
    treatment = build("Treatment recommendations", treat, "treatment", [diagnose])
    precautions = build("Precautions", prec, "precaution", [diagnose])
    summary = build(
        "Combined summary", coord, "summary", [diagnose, treatment, precautions]
    )
    return diagnose, treatment, precautions, summary
# ── 5. Parse coordinator output ────────────────────────────────────
def parse_summary(text):
    """Split markdown-style summary text into its three sections.

    Recognised headers are one or more ``#`` followed by "Diagnosis" /
    "Diagnoses", "Treatment(s)" or "Precaution(s)" (case-insensitive). Each
    section runs until the next recognised later header or the end of text.

    Args:
        text: Summary text to split.

    Returns:
        A dict with keys ``"diagnosis"``, ``"treatment"`` and
        ``"precautions"``. Sections that are not found are empty strings;
        when none is found the whole stripped text is returned under
        ``"diagnosis"``.
    """
    flags = re.S | re.I
    # ``#+`` accepts any header depth so "### Treatment" leaves no stray "#"
    # at the end of the preceding section. The optional plural "s" is
    # consumed by the header itself; otherwise "## Precautions" would yield
    # a section body that starts with "s".
    patterns = {
        "diagnosis": r"#+\s*Diagnos(?:is|es)(.*?)(?=#+\s*Treatment|#+\s*Precaution|$)",
        "treatment": r"#+\s*Treatments?(.*?)(?=#+\s*Precaution|$)",
        "precautions": r"#+\s*Precautions?(.*?)$",
    }

    sections = {"diagnosis": "", "treatment": "", "precautions": ""}
    for key, pattern in patterns.items():
        found = re.search(pattern, text, flags)
        if found:
            sections[key] = found.group(1).strip()

    # No recognisable structure: surface the raw text rather than nothing.
    if not any(sections.values()):
        sections["diagnosis"] = text.strip()
    return sections
# ── 6. Main orchestration ──────────────────────────────────────────
def _specialist_texts(result, tasks):
    """Return the raw texts of the first three tasks ("" where unavailable)."""
    # NOTE(review): assumes the crew result exposes ``tasks_output`` in task
    # order and that each entry (or ``task.output``) carries ``raw`` text;
    # confirm against the installed CrewAI version.
    outputs = getattr(result, "tasks_output", None)
    if not outputs:
        outputs = [getattr(task, "output", None) for task in tasks]
    texts = []
    for output in list(outputs)[:3]:
        raw = getattr(output, "raw", None)
        texts.append(str(raw).strip() if raw else "")
    # Pad so callers can always unpack three values.
    return texts + [""] * (3 - len(texts))


def analyze(symptoms, age, gender):
    """Run the crew for one patient and return three markdown panels.

    Args:
        symptoms: Free-text symptom description.
        age: Patient age; must be a whole number (given as text or int).
        gender: Patient gender, passed through to the task description.

    Returns:
        A ``(diagnosis, treatment, precautions)`` tuple of markdown strings,
        each ending with the disclaimer. For missing symptoms or an invalid
        age the first element is a warning and the other two are empty.
    """
    if not symptoms or not symptoms.strip():
        return "⚠️ Enter symptoms.", "", ""

    # The case sentence is "<age>-year-old ..."; without a numeric age the
    # sentence is malformed, so stop early with a message instead.
    age_text = "" if age is None else str(age).strip()
    if not age_text.isdigit():
        return "⚠️ Enter a valid age (whole number of years).", "", ""

    diag, treat, prec, coord = make_agents()
    tasks = make_tasks(symptoms, age_text, gender, diag, treat, prec, coord)
    crew = Crew(
        agents=[diag, treat, prec, coord],
        tasks=list(tasks),
        process=Process.sequential,
        verbose=False,
    )
    result = crew.kickoff()
    sections = parse_summary(str(result))

    # When the final summary carries no treatment/precaution sections (for
    # example free text without headers), those panels would be blank even
    # though the dedicated tasks produced answers. Use the per-task outputs
    # for all three panels in that case, and to fill any single empty panel.
    unstructured = not sections["treatment"] and not sections["precautions"]
    keys = ("diagnosis", "treatment", "precautions")
    for key, text in zip(keys, _specialist_texts(result, tasks)):
        if text and (unstructured or not sections[key]):
            sections[key] = text

    disclaimer = "\n\n---\n> ⚠️ **Disclaimer:** AI-generated output, not medical advice."
    return (
        f"### Diagnosis\n{sections['diagnosis']}{disclaimer}",
        f"### Treatment\n{sections['treatment']}{disclaimer}",
        f"### Precautions\n{sections['precautions']}{disclaimer}",
    )
def clear_all():
    """Return five empty strings, one per component wired to the Clear button."""
    return ("",) * 5
# ── 7. Gradio UI ───────────────────────────────────────────────────
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from statement order — confirm the intended layout.
with gr.Blocks(title="MedAgent — CrewAI + BioGPT") as demo:
    with gr.Row():
        # Narrow column: patient inputs and the two action buttons.
        with gr.Column(scale=1):
            symptoms_in = gr.Textbox(label="Symptoms", lines=6)
            age_in = gr.Textbox(label="Age")
            gender_in = gr.Dropdown(
                label="Gender",
                choices=["male", "female", "other"],
                value="male",
            )
            analyze_btn = gr.Button("Analyze")
            clear_btn = gr.Button("Clear")
        # Wide column: one markdown panel per result section.
        with gr.Column(scale=2):
            diagnosis_out = gr.Markdown()
            treatment_out = gr.Markdown()
            precautions_out = gr.Markdown()

    result_panels = [diagnosis_out, treatment_out, precautions_out]

    analyze_btn.click(
        fn=analyze,
        inputs=[symptoms_in, age_in, gender_in],
        outputs=result_panels,
    )
    # Clear resets the two text inputs and all three panels; the gender
    # dropdown is not among the outputs and keeps its value.
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[symptoms_in, age_in] + result_panels,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)