ClinicDx V1

ClinicDx V1 is a fine-tuned multimodal clinical decision support (CDS) model based on google/medgemma-4b-it. It is trained to generate structured, evidence-grounded clinical assessments from patient presentations, integrating a retrieval-augmented knowledge base (KB) pipeline and an audio input pathway for voice-driven clinical observation extraction.

ClinicDx is an open-source trimodal inference system for edge clinical AI, combining a medical ASR encoder, a learned audio projector, and a fine-tuned 4B clinical LLM in a single llama.cpp binary, deployable fully offline on consumer hardware. For research contributions, open problems, and comparison to prior work, see RESEARCH.md on GitHub.

GitHub · Website · npm Package

This repository contains all four artifacts needed to run the full system with llama-server:

| File | Size | Description |
|---|---|---|
| clinicdx-v1-q8.gguf | 3.9 GB | ClinicDx V1 language model (Q8_0 quantisation) |
| medasr-encoder.gguf | 401 MB | MedASR Conformer encoder (frozen, 105M params) |
| audio-projector-v3-best.gguf | 46 MB | AudioProjector v3 best checkpoint (step 40000, val LM loss 0.1042) |
| who_knowledge_vec_v2.mv2 | 1.1 GB | WHO/MSF knowledge base v2.1 (27,860 chunks, BM25 + semantic hybrid) |

Knowledge Base

The who_knowledge_vec_v2.mv2 index contains 27,860 chunks from WHO and MSF clinical guidelines, built via a Docling + HybridChunker pipeline with safety keyword detection for life-threatening conditions.

The retrieval pipeline uses hybrid search (BM25 + EmbedGemma 300M semantic, merged via Reciprocal Rank Fusion) followed by a 4-slot clinical intent reranker that extracts condition, severity, population, and task from each query and rescores hits with multiplicative penalties (×0.12 for off-condition, ×0.35 for severity mismatch, ×0.30 for wrong population) and additive boosts for slot alignment. The reranker covers 30+ clinical conditions with inclusion/exclusion patterns and condition-specific overrides.
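The two rescoring stages above can be sketched as follows. This is a minimal illustration, not the production reranker: the RRF constant (k=60) and the mismatch representation are assumptions, while the penalty factors are the ones quoted above.

```python
# Sketch of Reciprocal Rank Fusion followed by multiplicative penalty
# rescoring. k=60 is a conventional RRF constant, assumed here; the
# penalty values are the ones stated in the text.

def reciprocal_rank_fusion(rankings, k=60):
    """Merge several ranked lists of doc IDs into one RRF-scored dict."""
    scores = {}
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return scores

PENALTIES = {
    "off_condition": 0.12,
    "severity_mismatch": 0.35,
    "wrong_population": 0.30,
}

def rescore(score, mismatches):
    """Apply a multiplicative penalty for each mismatched intent slot."""
    for m in mismatches:
        score *= PENALTIES[m]
    return score
```

A document ranked highly by both retrievers accumulates score from both lists, then the intent reranker scales it down for each slot it gets wrong.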

During CDS inference, the model controls its own retrieval via a multi-turn ReAct loop: it emits <KB_QUERY> tags that the middleware resolves against this index and injects as <KB_RESULT> context, for up to 5 retrieval turns per request.
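The middleware side of that loop can be sketched as below. This is illustrative pseudologic under assumptions: generate() and kb_search() are hypothetical stand-ins for the llama-server call and the index lookup, and the closing-tag form of <KB_QUERY> is assumed.

```python
# Illustrative middleware loop for model-controlled retrieval.
# generate() and kb_search() are hypothetical stand-ins.
import re

MAX_KB_TURNS = 5  # retrieval budget per request, as stated in the text

def react_loop(prompt, generate, kb_search):
    context = prompt
    for _ in range(MAX_KB_TURNS):
        output = generate(context)
        match = re.search(r"<KB_QUERY>(.*?)</KB_QUERY>", output, re.DOTALL)
        if not match:
            return output  # no retrieval requested: this is the final answer
        hits = kb_search(match.group(1).strip())
        context += output + f"<KB_RESULT>{hits}</KB_RESULT>"
    return generate(context)  # budget exhausted: force a final answer
```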


Model Description

ClinicDx V1 is a LoRA-fine-tuned and fully merged version of MedGemma 4B Instruct. It was trained with input masking (only model output turns are trained; user and KB turns are masked) on a quality-filtered dataset of 27,592 clinical conversations augmented with KB retrieval.
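The input-masking scheme can be sketched as label construction over conversation turns, using the standard -100 ignore index from the Hugging Face training convention; the turn representation here is an illustrative assumption.

```python
# Sketch of input masking: cross-entropy loss is computed only on model
# turns; user and KB turns receive the ignore index.
IGNORE_INDEX = -100

def build_labels(turns):
    """turns: list of (role, token_ids). Returns flat labels for CE loss."""
    labels = []
    for role, token_ids in turns:
        if role == "model":
            labels.extend(token_ids)                        # trainable
        else:
            labels.extend([IGNORE_INDEX] * len(token_ids))  # masked
    return labels
```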

The model generates structured 6-section responses with inline citations sourced from retrieved knowledge base content.


Training Details

| Parameter | Value |
|---|---|
| Base model | google/medgemma-4b-it |
| Training method | LoRA (r=64, α=128), fully merged |
| Training cycle | Cycle 2 (masked) |
| Dataset | production_run_v2 (quality-filtered) |
| Train samples | 27,592 |
| Validation samples | 1,452 |
| Input masking | 64% masked (user + KB turns), 36% trainable (model turns only) |
| Best eval loss | 0.4758 |
| Best eval accuracy | 86.25% |
| Best checkpoint | Step 4000 |
| Epochs trained | 2.32 (early stopped, patience=5) |
| LoRA target modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| LoRA dropout | 0.05 |
| Max sequence length | 8192 |
| Precision | bfloat16 |

Output Schema

Each response is structured into 6 sections:

  1. Alert Level: urgency classification (e.g. Routine / Urgent / Emergency)
  2. Clinical Assessment: summary of presenting findings and clinical reasoning
  3. Differential Considerations: ranked differential diagnoses with rationale
  4. Recommended Actions: investigations and immediate management steps
  5. Safety Alerts: red-flag signs, drug interactions, contraindications
  6. Key Points: concise summary for handover or documentation

KB integration uses an EXTRACTED/BANKED think-block pattern with inline [WHO: source] citations in the final response.
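For illustration, the inline citations can be pulled out with a small regex; the exact tag grammar is an assumption based on the [WHO: source] pattern named above.

```python
# Extract inline KB citations of the form [WHO: source] from a response.
# The regex is an assumption based on the pattern quoted in the text.
import re

def extract_citations(response):
    return re.findall(r"\[WHO:\s*([^\]]+)\]", response)
```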


Architecture

  • Base: Gemma3ForConditionalGeneration (4.3B parameters)
  • LoRA adapters: Merged into base weights β€” no adapter files needed at inference
  • Vision tower: Present (inherited from MedGemma base, frozen, not used in CDS or Scribe)
  • Audio projector: Included in this repository as audio-projector-v3-best.gguf

Audio Projector & Voice Input

The ClinicDx production server combines this model with a MedASR encoder and a lightweight AudioProjector to enable voice-to-CDS inference. The architecture mirrors how Gemma3 integrates vision: a frozen encoder feeds a trainable projector whose output is injected into the LLM's embedding sequence.

Full System Architecture

```
Patient audio (16kHz mono WAV)
        |
        v
MedASR Conformer Encoder  (frozen, 105M params)
  Mel spectrogram (128 bins, hop=160, n_fft=512)
  Natural log + 1e-5 clamp normalisation
  -> 17-layer Conformer encoder, 512 hidden dim
  -> [B, T_enc, 512]  (T_enc ≈ audio_seconds × 50)
        |
        v
AudioProjector v3  (trainable, 11,806,720 params)
  Frame stacking k=4:  [B, T_enc, 512] -> [B, T_enc/4, 2048]
  Linear(2048 -> 2560, bias=False)
  RMSNorm(2560)
  GELU
  Linear(2560 -> 2560, bias=False)
  LayerNorm(2560)  [ln_final, added in v3]
  Pad (learned padding embedding) or truncate to 64 tokens
  -> [B, 64, 2560]  (MedGemma embedding space)
        |
        v
ClinicDx V1 Language Model  (4.3B params)
  <image_soft_token> × 64 placeholders in the text sequence
  are replaced with projected audio embeddings via masked_scatter
  (reuses Gemma3's image token injection mechanism)
        |
        v
Structured medical observations (key: value format)
```

AudioProjector Architecture Detail

The Gemma3AudioProjector is the only trainable component during audio projector training: a 2-layer MLP with frame stacking and a final LayerNorm. A runnable PyTorch sketch of the module (the handling of trailing frames and the padding-embedding initialisation are assumed details):

```python
import torch
import torch.nn as nn

class Gemma3AudioProjector(nn.Module):
    def __init__(self, enc_dim=512, stack=4, out_dim=2560, n_tokens=64):
        super().__init__()
        self.stack, self.n_tokens = stack, n_tokens
        self.proj = nn.Sequential(
            nn.Linear(enc_dim * stack, out_dim, bias=False),   # 2048 -> 2560
            nn.RMSNorm(out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim, bias=False),           # 2560 -> 2560
        )
        self.ln_final = nn.LayerNorm(out_dim)                  # added in v3
        self.audio_padding_emb = nn.Parameter(torch.zeros(1, 1, out_dim))

    def forward(self, x):  # x: [B, T_enc, 512] -> [B, 64, 2560]
        b, t, d = x.shape
        x = x[:, : t - t % self.stack].reshape(b, -1, d * self.stack)
        x = self.ln_final(self.proj(x))
        if x.size(1) < self.n_tokens:  # pad with the learned embedding
            pad = self.audio_padding_emb.expand(b, self.n_tokens - x.size(1), -1)
            x = torch.cat([x, pad], dim=1)
        return x[:, : self.n_tokens]   # truncate if longer than 64
```

Trainable parameters breakdown (11,806,720 total):

| Tensor | Shape | Parameters |
|---|---|---|
| audio_padding_emb | [1, 1, 2560] | 2,560 |
| proj.0.weight (Linear 1) | [2560, 2048] | 5,242,880 |
| proj.1.weight (RMSNorm) | [2560] | 2,560 |
| proj.3.weight (Linear 2) | [2560, 2560] | 6,553,600 |
| ln_final.weight (LayerNorm) | [2560] | 2,560 |
| ln_final.bias (LayerNorm) | [2560] | 2,560 |

Token budget per audio duration (16kHz, hop=160, ~50 frames/sec encoder output, 4× stacking):

| Audio length | Encoder frames | After stacking | After pad/trunc |
|---|---|---|---|
| 1 second | ~50 | ~13 | 64 (padded) |
| 3 seconds | ~150 | ~38 | 64 (padded) |
| 5 seconds | ~250 | ~63 | 64 (padded) |
| 10 seconds | ~500 | ~125 | 64 (truncated) |
| 20 seconds | ~1000 | ~250 | 64 (truncated) |
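The table's arithmetic can be reproduced directly: ~50 encoder frames per second, 4× frame stacking (ceiling rounding is assumed here), and a fixed 64-token budget.

```python
# Reproduces the token-budget arithmetic from the table above.
import math

FRAMES_PER_SEC = 50   # approximate encoder output rate
STACK = 4             # frame-stacking factor
BUDGET = 64           # fixed audio token count

def audio_tokens(seconds):
    enc = round(seconds * FRAMES_PER_SEC)
    stacked = math.ceil(enc / STACK)
    mode = "truncated" if stacked > BUDGET else "padded"
    return {"encoder_frames": enc, "stacked": stacked, "final": BUDGET, "mode": mode}
```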

AudioProjector Training

The projector was trained independently from the CDS LoRA on a clinical audio dataset:

| Parameter | Value |
|---|---|
| Config | train_config_mvp.yaml |
| Trainable params | 11,806,720 (projector only) |
| Frozen params | ~4.5B (base model + MedASR encoder) |
| Training data | 47,500 pre-computed audio clips (MVP audio dataset) |
| Validation data | 2,500 clips (5% held-out split) |
| Batch size | 8 |
| Learning rate | 5.0e-4 (AdamW) |
| Warmup steps | 500 |
| Gradient clipping | 1.0 |
| Epochs | 10 |
| Best checkpoint | Step 40000 (epoch 6 of 10) |
| Best val LM loss | 0.1042 |
| Best val key accuracy | 84.0% |
| Precision | bfloat16 |
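The schedule can be sketched as a linear warmup to the peak learning rate over 500 steps; the post-warmup shape (held constant here) is an assumption, since the table specifies only the warmup length.

```python
# Linear warmup to the peak LR over 500 steps; constant afterwards
# (the post-warmup shape is an assumption).
PEAK_LR = 5.0e-4
WARMUP_STEPS = 500

def lr_at(step):
    return PEAK_LR * min(step / WARMUP_STEPS, 1.0)
```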

Validation history:

| Step | Val LM Loss | Key Accuracy |
|---|---|---|
| 5,000 | 0.1496 | 77.1% |
| 10,000 | 0.1262 | 77.9% |
| 15,000 | 0.1198 | 77.9% |
| 20,000 | 0.1169 | 79.4% |
| 25,000 | 0.1115 | 84.7% |
| 30,000 | 0.1094 | 82.4% |
| 35,000 | 0.1062 | 80.9% |
| 40,000 | 0.1042 ✓ | 84.0% |
| 45,000 | 0.1123 | 84.7% |
| 50,000 | 0.1200 | 86.3% |

Step 40,000 produced the best generalisation (lowest val LM loss). Training continued for an additional 10,000 steps but did not improve val LM loss, indicating the onset of overfitting. The audio-projector-v3-best.gguf file contains the weights from this step.

Loss functions:

```
L_lm          = cross-entropy on target output tokens (main loss)
L_contrastive = cosine-similarity loss between projected audio embeddings
                and concept text embeddings (single-phrase clips only)
L_total = L_lm + 0.1 × L_contrastive
```
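A minimal numeric sketch of the combined objective. Treating the contrastive term as 1 - cosine similarity is an assumption (the text only names cosine similarity); the 0.1 weight is from the formula above, and L_lm stands in for the real token-level cross-entropy.

```python
# Combined objective sketch: L_total = L_lm + 0.1 * L_contrastive.
import math

def cosine_loss(a, b):
    """1 - cosine similarity between two embedding vectors (an assumption)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return 1.0 - dot / (na * nb)

def total_loss(l_lm, audio_emb, text_emb):
    return l_lm + 0.1 * cosine_loss(audio_emb, text_emb)
```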

Audio Token Details

The audio pathway reuses Gemma3's existing image token injection mechanism. The audio soft token (<image_soft_token>, ID 262144) is repurposed as the audio placeholder. No new tokens are added to the vocabulary.

| Token | ID | Purpose |
|---|---|---|
| <start_of_image> | 255,999 | Begin-of-audio delimiter (reused for audio) |
| <end_of_image> | 256,000 | End-of-audio delimiter (reused for audio) |
| <image_soft_token> | 262,144 | Audio embedding placeholder (×64 per clip) |

The system prompt uses <start_of_audio> / <end_of_audio> as human-readable markers in the text, while the tokenised form uses the image token IDs above for embedding injection.
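A minimal sketch of the placeholder layout and the embedding substitution. The token IDs come from the table above; the list-based inject() is a pure-Python stand-in for the tensor-level masked_scatter step, and the embed callable is hypothetical.

```python
# Placeholder span layout and embedding injection for one audio clip.
BOA, EOA, SOFT = 255999, 256000, 262144  # reused image token IDs
N_AUDIO = 64

def audio_span_ids():
    """Token IDs for one audio clip: delimiter, 64 soft tokens, delimiter."""
    return [BOA] + [SOFT] * N_AUDIO + [EOA]

def inject(token_ids, embed, audio_embs):
    """Replace each soft-token embedding with the next projected audio one."""
    it = iter(audio_embs)
    return [next(it) if t == SOFT else embed(t) for t in token_ids]
```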


Running with llama-server (GGUF, Recommended)

The fastest deployment path uses the three GGUF files in this repository with a CUDA-enabled llama-server build that includes --medasr-encoder and --audio-proj support.

Prerequisites

  • NVIDIA GPU with β‰₯8 GB VRAM (β‰₯12 GB recommended for full Q8 + encoder + projector)
  • llama-server built with CUDA and MedASR/audio-projector support
  • ffmpeg available on the host (for audio transcoding to 16kHz PCM-16 WAV)

Download GGUFs

```shell
pip install huggingface_hub
python - <<'EOF'
from huggingface_hub import snapshot_download
snapshot_download(
    repo_id="ClinicDx1/ClinicDx",
    allow_patterns=["*.gguf"],
    local_dir="./clinicdx-gguf"
)
EOF
```

Start the Server

```shell
llama-server \
  --model             ./clinicdx-gguf/clinicdx-v1-q8.gguf \
  --medasr-encoder    ./clinicdx-gguf/medasr-encoder.gguf \
  --audio-proj        ./clinicdx-gguf/audio-projector-v3-best.gguf \
  --n-gpu-layers      999 \
  --ctx-size          8192 \
  --parallel          1 \
  --threads           8 \
  --host              0.0.0.0 \
  --port              8180
```

Note: --parallel 1 is required. The audio extraction endpoint (/v1/audio/extract) performs a blocking llama_decode on the shared context; parallel slots cause assertion failures.

Audio Inference (via REST API)

```shell
# Transcode browser audio to required format first
ffmpeg -i input.webm -ar 16000 -ac 1 -c:a pcm_s16le output.wav

# Send to the audio extraction endpoint
curl -X POST http://localhost:8180/v1/audio/extract \
  -H "Content-Type: audio/wav" \
  --data-binary @output.wav
```

The endpoint returns structured key: value observations matching the clinical manifest provided in the system prompt.
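A small helper for consuming that output; the field names in the example are hypothetical, since the real manifest is supplied by the system prompt.

```python
# Parse the structured "key: value" lines returned by /v1/audio/extract.
def parse_observations(text):
    obs = {}
    for line in text.splitlines():
        if ":" in line:
            key, _, value = line.partition(":")
            obs[key.strip()] = value.strip()
    return obs
```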


Usage (Text CDS, no audio)

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "ClinicDx1/ClinicDx"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

prompt = """<start_of_turn>user
Patient: 45-year-old male, 2 days of fever (39.2°C), productive cough, right-sided pleuritic chest pain, decreased breath sounds right base.
<end_of_turn>
<start_of_turn>model
"""

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=1024, temperature=0.1, do_sample=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

Intended Use

  • Clinical decision support for trained healthcare professionals
  • Structured differential diagnosis generation
  • Evidence-grounded treatment planning with KB citations
  • Voice-driven clinical observation extraction in low-resource clinical settings
  • Not intended for direct patient-facing use or autonomous clinical decision making

Limitations and Open Problems

  • No formal clinical validation. Accuracy metrics (86.25% CDS, 84% Scribe key accuracy) are measured on held-out synthetic data, not on real clinical encounters. A prospective evaluation with practicing clinicians is the highest-priority gap.
  • English only. All CDS outputs, Scribe extraction, and KB retrieval operate in English. Multilingual support (Swahili, Amharic, Hausa, Yoruba) is not yet implemented.
  • Audio token norm mismatch. Projected audio token norms are 360-620x larger than text token norms in the LLM embedding space. The current mitigation uses adaptive norm alignment loss, but this remains an active area of investigation.
  • Synthetic training data. Trained on curated synthetic/augmented clinical data β€” real-world performance may vary.
  • KB-dependent for best results. Knowledge base integration requires the ClinicDx retrieval pipeline; standalone use generates structure but without live KB citations.
  • Audio projector trained on synthetic speech. Accuracy on natural conversational speech, accented English, or noisy clinical environments may be lower.
  • Q8 quantization chosen empirically. Q8_0 was selected over Q4 variants because lower quantization degraded structured output behavior (THINK block coherence, KB query emission) more than perplexity. No systematic ablation study has been conducted.
  • No ARM / low-power benchmarks. Validated on x86_64 with NVIDIA GPUs and CPU-only mode. Latency on ARM edge devices is unknown.
  • The model may produce plausible-sounding but incorrect clinical information β€” always verify with a qualified clinician.

License

This model is released under the Gemma Terms of Use. Use is subject to those terms.
