File size: 4,158 Bytes
a8a971d
 
 
3991281
a8a971d
 
 
3bf3c5d
19b2a18
3bf3c5d
19b2a18
a8a971d
b8a194c
290899c
b8a194c
a8a971d
b8a194c
 
5b54a4e
b8a194c
 
5b54a4e
b8a194c
 
5b54a4e
b8a194c
 
 
 
 
5b54a4e
b8a194c
 
5b54a4e
b8a194c
a8a971d
 
 
6b91224
 
 
3bf3c5d
6b91224
a8a971d
 
 
 
 
 
6b91224
 
290899c
a8a971d
b8a194c
a8a971d
19b2a18
 
 
5b54a4e
 
19b2a18
 
 
 
 
 
 
 
 
b8a194c
 
a8a971d
b8a194c
a8a971d
 
6b91224
19b2a18
a8a971d
 
6b91224
a8a971d
6b91224
a8a971d
6b91224
a8a971d
 
 
 
 
 
 
 
 
 
 
6b91224
 
 
 
 
 
 
 
3bf3c5d
6b91224
3bf3c5d
69efb95
6b91224
 
a8a971d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
DeepMed-R1: Medical Reasoning AI Demo

Gradio chat app that answers clinical multiple-choice questions with a fixed,
step-by-step reasoning layout. Each turn is sent, together with ``SYSTEM_PROMPT``
and the prior conversation, to ``MODEL`` through ``huggingface_hub.InferenceClient``
(token taken from the ``HF_TOKEN`` environment variable), and the reply is streamed
back into the chat. Research demo only; not for clinical use.
"""
import logging
import os

import gradio as gr
from huggingface_hub import InferenceClient

# Model id passed to ``client.chat_completion`` for every chat turn.
MODEL = "Qwen/Qwen2.5-72B-Instruct"

# Access token from the environment (``None`` when the variable is not set),
# handed straight to the shared inference client used by the chat handler.
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(token=HF_TOKEN)

# System message placed first in every conversation built by ``respond``. It asks
# the model for a fixed five-step layout: key findings, differential diagnosis,
# pathophysiology, per-option analysis, and a bolded final answer letter.
# NOTE(review): Step 4 hard-codes options A-D, so questions with a different number
# of options will still be asked for exactly four evaluations -- confirm intended.
# NOTE(review): the persona claims GRPO training on AMD MI300X, but ``MODEL`` points
# at the stock "Qwen/Qwen2.5-72B-Instruct" checkpoint; verify before showing users.
SYSTEM_PROMPT = """You are DeepMed-R1, a medical reasoning AI trained with GRPO on AMD MI300X.

For EVERY question you MUST follow this format exactly:

**Step 1 - Key Findings:**
[List the important clinical features from the question]

**Step 2 - Differential Diagnosis:**
[List 3-4 possible diagnoses ranked by likelihood]

**Step 3 - Pathophysiology:**
[Explain WHY the most likely diagnosis causes these symptoms]

**Step 4 - Option Analysis:**
- A: [evaluate this option]
- B: [evaluate this option]
- C: [evaluate this option]
- D: [evaluate this option]

**Step 5 - Final Answer:**
The answer is **X** because [one sentence justification].

IMPORTANT: Always show ALL 5 steps. Never skip steps. Be detailed and specific."""


# Sample prompts offered under the chat box. Each entry is a one-element list
# holding a clinical vignette followed by its lettered options A-D, one per line.
EXAMPLES = [
    [vignette]
    for vignette in (
        "A 65-year-old male with hypertension presents with sudden 'worst headache of my life,' neck stiffness, photophobia. BP 180/100. Most likely diagnosis?"
        "\nA. Migraine"
        "\nB. Subarachnoid hemorrhage"
        "\nC. Meningitis"
        "\nD. Tension headache",
        "A 28-year-old woman: fatigue, weight gain, cold intolerance. TSH 12 mIU/L, Free T4 0.5 ng/dL. Initial treatment?"
        "\nA. Levothyroxine"
        "\nB. Liothyronine"
        "\nC. Methimazole"
        "\nD. Radioactive iodine",
        "3-month-old with projectile non-bilious vomiting, olive-shaped RUQ mass, metabolic alkalosis. Diagnosis?"
        "\nA. Pyloric stenosis"
        "\nB. Intussusception"
        "\nC. Malrotation"
        "\nD. Hirschsprung disease",
        "55-year-old diabetic: RUQ pain, fever 39.2C, jaundice (Charcot triad). WBC 18K, bilirubin 5.2, CBD stone. Next step?"
        "\nA. Cholecystectomy"
        "\nB. ERCP with sphincterotomy"
        "\nC. MRCP"
        "\nD. PTC",
        "22-year-old post-MVC: left chest pain, absent breath sounds left, trachea deviated right, JVD, BP 80/50. Immediate management?"
        "\nA. Chest X-ray"
        "\nB. CT chest"
        "\nC. Needle decompression left chest"
        "\nD. Intubation",
    )
]


_logger = logging.getLogger(__name__)


def _message_text(content):
    """Return the plain text carried by a chat message's ``content`` value.

    Strings are returned unchanged. A list of dict parts contributes the ``"text"``
    string of every part that has one (assumed multi-part message shape -- verify
    against the installed Gradio version). Any other value, including ``None``,
    yields ``""`` so that it is left out of the prompt.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return ""


def _history_to_messages(history):
    """Convert Gradio chat history into OpenAI-style ``{"role", "content"}`` dicts.

    Accepts both shapes a chat history may take:
    - role/content dicts (``{"role": "user", "content": "..."}``), and
    - ``[user_text, assistant_text]`` pairs, where either side may be empty/None.

    Turns with no text, and dict turns whose role is not "user" or "assistant",
    are dropped. ``None`` or an empty history yields ``[]``.
    """
    messages = []
    for turn in history or ():
        if isinstance(turn, dict):
            role = turn.get("role")
            text = _message_text(turn.get("content"))
            if role in ("user", "assistant") and text:
                messages.append({"role": role, "content": text})
            continue
        # Pair format: the first slot is the user's turn, the second the reply.
        for role, content in zip(("user", "assistant"), turn):
            text = _message_text(content)
            if text:
                messages.append({"role": role, "content": text})
    return messages


def respond(message, history):
    """Stream a structured answer to *message* from ``MODEL``.

    Generator used as the ``gr.ChatInterface`` callback. It sends ``SYSTEM_PROMPT``,
    the prior turns from *history*, and then *message* to ``client.chat_completion``
    with streaming enabled, and yields the accumulated reply after every non-empty
    content delta so the chat can render it progressively.

    Args:
        message: The user's new chat message.
        history: Prior turns, either as role/content dicts or as
            ``[user, assistant]`` pairs.

    Yields:
        The reply text so far. If the stream produces no text, a single
        "No response received" notice. If the request fails, the partial reply
        (if any) followed by a "⚠️ Error: ..." line; the exception is also logged.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    response = ""
    try:
        stream = client.chat_completion(
            model=MODEL,
            messages=messages,
            max_tokens=4000,
            temperature=0.4,
            top_p=0.95,
            stream=True,
        )
        for chunk in stream:
            if not chunk.choices:
                continue
            piece = getattr(chunk.choices[0].delta, "content", None)
            if piece:
                response += piece
                yield response
    except Exception as exc:
        # UI boundary: surface the failure in the chat instead of crashing the
        # handler, keep whatever text already streamed, and log the traceback.
        _logger.exception("chat_completion failed")
        error_text = f"⚠️ Error: {exc}"
        yield f"{response}\n\n{error_text}" if response else error_text
    else:
        if not response:
            yield "No response received. Please try again."


# Page layout: intro header, the chat widget wired to ``respond``, then an
# architecture summary with links to the project and related papers.
with gr.Blocks(title="DeepMed-R1", theme=gr.themes.Soft(primary_hue="blue")) as demo:
    # Header: project name, tagline, and the research-only disclaimer.
    gr.Markdown("""
# 🏥 DeepMed-R1: Medical Reasoning AI

**Systematic clinical reasoning powered by GRPO + Multi-Objective Clinical Rewards**

Built for **AMD Developer Hackathon 2026** — Track 2: Fine-Tuning on AMD GPUs

> ⚠️ Research demo only. Not for clinical use.

---
    """)

    # Chat widget: every submitted message is answered by the ``respond`` generator.
    chatbot = gr.ChatInterface(
        fn=respond,
        examples=EXAMPLES,
        title="",
    )

    # Footer: architecture summary and links. The "Base Model" row names Qwen3,
    # while this app actually calls ``MODEL``; the last table row states that
    # explicitly so users know which checkpoint produced the answers.
    gr.Markdown(f"""
---
### 🔬 Architecture

| Component | Detail |
|-----------|--------|
| Base Model | Qwen3 (RL-trained reasoning) |
| Training | GRPO + DAPO loss + CRPO rewards |
| Innovations | iGRPO (Feb 2026) + AERO (Feb 2026) + Curriculum (Mar 2026) |
| Hardware | AMD MI300X (192GB HBM3) |
| Reward | Accuracy (2.0) + Reasoning (1.0) + Consistency (0.5) + Length (0.3) |
| Demo inference model | `{MODEL}` (prompted, served via Hugging Face inference) |

📄 [Code & Training](https://huggingface.co/asusf15/DeepMed-R1) |
📚 [Gazal-R1](https://arxiv.org/abs/2506.21594) |
📚 [Clinical-R1](https://arxiv.org/abs/2512.00601) |
📚 [iGRPO](https://arxiv.org/abs/2602.09000)
    """)

if __name__ == "__main__":
    demo.launch()