---
license: apache-2.0
datasets:
- qiaojin/PubMedQA
- MedAI-COS30018/PubMedQA-map
- MedAI-COS30018/PubmedQA-u
- MedAI-COS30018/PubMedQA-l
- MedAI-COS30018/HealthCareMagic
- MedAI-COS30018/iCliniq
language:
- en
base_model:
- medalpaca/medalpaca-7b
- google/medgemma-27b-it
pipeline_tag: question-answering
metrics:
- bertscore: 0.8441
tags:
- medical
- knowledge-distillation
---
# Model Card
## Model Description
**MedSwin-7B-KD** is a high-performance 7B parameter language model for medical question-answering and clinical reasoning. It was created by applying a novel **Dual-Phase Knowledge Distillation (KD)** pipeline to the `medalpaca/medalpaca-7b` base model. Unlike its SFT predecessor, this model leverages the superior knowledge and reasoning capabilities of the larger `google/medgemma-27b-it` model as a "teacher" to guide the training of the smaller, more efficient "student" model. This results in a compact model that captures the clinical acumen of a much larger counterpart.
- **Developed by:** Swinburne University of Technology Medical AI Team
- **Funded by:** [Swinburne University of Technology](https://www.swinburne.edu.au)
- **Base Model (Student):** [medalpaca/medalpaca-7b](https://huggingface.co/medalpaca/medalpaca-7b)
- **Teacher Model:** [google/medgemma-27b-it](https://huggingface.co/google/medgemma-27b-it)
- **Language(s):** English
- **License:** Apache 2.0
### Intended Use
This model is intended for research purposes in the following domains:
* AI-assisted medicine and clinical decision support research.
* Biomedical natural language processing (NLP).
* Exploration of efficient knowledge distillation and model compression in specialized domains.
* Generating high-quality, clinically-grounded synthetic data.
## Training Data
The model was trained on the same curated and augmented collection of medical QA datasets as the SFT version, but the *target outputs* were generated by the teacher model.
- **PubMedQA**: Original and processed (map, u, l) variants for factoid and research-oriented questions.
- **HealthCareMagic** & **iCliniq**: Real-world patient-doctor interactions from online portals.
### Data Curation & Knowledge Distillation Pipeline
The training pipeline was fundamentally redesigned to center on knowledge distillation, moving beyond simple paraphrasing to focus on transferring deep reasoning patterns.
| Stage | Purpose | Methodology & Quality Control |
| :--- | :--- | :--- |
| **A. Augmented Query Generation** | Create a diverse set of high-quality input prompts. | Utilizes the same multi-model paraphrasing, back-translation, and style standardization pipeline from the SFT model to generate a rich variety of instructions and inputs. |
| **B. Teacher Forcing & Output Generation** | Generate "gold-standard" responses using the superior teacher model. | **Teacher Model:** `google/medgemma-27b-it`. <br/> **Generation Strategy:** Low-temperature sampling with contrastive decoding to produce confident, factually-dense, and well-structured answers. <br/> **Input:** The entire augmented set of `(Instruction, Input)` pairs from Stage A. |
| **C. Response Filtering & Alignment** | Ensure the teacher's outputs are of the highest quality for student training. | **Factual Consistency Check:** Cross-referencing key medical claims against the original context. <br/> **Style Alignment:** Enforcing the neutral, professional clinical tone. <br/> **Complexity Pruning:** Removing outputs that are overly verbose or rely on reasoning chains too complex for the student model to learn effectively. |
| **D. Dual-Phase Knowledge Distillation** | Transfer knowledge from teacher to student. | **Phase 1 (Response Mimicking):** The student model is trained to directly reproduce the teacher's filtered outputs, learning its style and factual presentation. <br/> **Phase 2 (Logit Matching):** The student is trained to align its internal probability distributions (logits) with the teacher's for the same input, capturing the teacher's "thinking process" and confidence calibration. |
| **E. Quality Assurance** | Ensure the final training pairs are optimal for distillation. | **E1. Data Cleaning:** PHI removal; MD5-based deduplication. <br/> **E2. KD-Specific Validation:** Checking for alignment between query complexity and response depth; ensuring student-trainable reasoning patterns. |
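For intuition, the two distillation phases in Stage D can be sketched per token position. The model card does not publish the exact training objectives, so the cross-entropy term, the temperature-softened KL term, and the temperature value below are standard knowledge-distillation assumptions, not the team's implementation:

```python
import math

def softmax(logits, temperature=1.0):
    """Convert raw logits to a probability distribution (numerically stable)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def response_mimicking_loss(student_logits, target_id):
    """Phase 1: cross-entropy against the token the teacher actually generated."""
    probs = softmax(student_logits)
    return -math.log(probs[target_id])

def logit_matching_loss(student_logits, teacher_logits, temperature=2.0):
    """Phase 2: KL(teacher || student) over temperature-softened distributions."""
    p = softmax(teacher_logits, temperature)  # teacher's softened distribution
    q = softmax(student_logits, temperature)  # student's softened distribution
    # T^2 scaling keeps gradient magnitudes comparable across temperatures
    return temperature ** 2 * sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# One vocabulary position with 4 candidate tokens (toy numbers)
student = [2.0, 1.0, 0.5, 0.1]
teacher = [2.5, 0.8, 0.3, 0.0]
loss = response_mimicking_loss(student, target_id=0) + logit_matching_loss(student, teacher)
```

In practice the two phases are run over full sequences and the real vocabulary, but the per-position structure is the same: Phase 1 only needs the teacher's text, while Phase 2 needs access to the teacher's logits.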
## Output Format
All training data was formatted into the same standardized SFT structure, but the outputs are now teacher-generated:
```
### Instruction:
{Task descriptor and/or user question with context}
### Input:
{Additional user question or context, if any}
### Output:
{The teacher model's (MedGemma-27b) target response}
```
Each data point includes metadata tags for its augmentation source and a `distilled_from: medgemma-27b` tag.
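A data point can be rendered into this template with a small helper, sketched below. The `distilled_from: medgemma-27b` tag is taken from the card; the `augmentation_source` field name and the example values are illustrative assumptions:

```python
def format_example(instruction, teacher_output, context=""):
    """Render one training pair in the Instruction/Input/Output template above."""
    return (
        f"### Instruction:\n{instruction}\n\n"
        f"### Input:\n{context}\n\n"
        f"### Output:\n{teacher_output}"
    )

record = {
    "text": format_example(
        "Summarize the key finding of the study.",
        "The trial showed a significant reduction in LDL cholesterol.",
        context="A randomized trial of statin therapy in 1,200 adults...",
    ),
    # Metadata tags carried with each data point
    "distilled_from": "medgemma-27b",
    "augmentation_source": "back_translation",  # hypothetical example value
}
```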
## Usage
You can load and use the model with the Hugging Face `transformers` library, identical to the SFT version but with potentially improved performance.
```python
import transformers
model_id = "MedAI-COS30018/MedSwin-7B-KD"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",  # use GPU if available (requires `accelerate`)
)

# Format your input according to the training template
instruction = "Based on the provided context, what is the most likely diagnosis?"
context = "A 45-year-old male presents with acute, crushing substernal chest pain radiating to the left arm, associated with diaphoresis and nausea for the past hour."
formatted_prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n### Output:\n"

# Generate a response
sequences = pipeline(
    formatted_prompt,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    eos_token_id=pipeline.tokenizer.eos_token_id,
)
print(sequences[0]['generated_text'])
```
## Bias, Risks, and Limitations
The model inherits and may amplify biases present in its base model, teacher model, and training data. These can include:
* **Demographic Biases:** Biases related to race, gender, age, or socioeconomic status based on patterns in the source data.
* **Clinical Biases:** Potential over-representation of certain conditions, treatments, or clinical perspectives.
* **Factual Accuracy:** While the teacher model is highly capable, it is not infallible. The distilled model may propagate or even amplify any errors made by the teacher. It is not a certified medical knowledge base and can generate incorrect or outdated information.
* **Safe Deployment:** Use a **Human-in-the-Loop** (HITL) system for any real-world application. Outputs **must** be verified by a qualified healthcare professional. **Do not use for direct patient care without rigorous clinical validation.**
## Technical Specifications & Evaluation
* **Model Architecture:** Based on LLaMA, fine-tuned via Dual-Phase Knowledge Distillation.
* **Model Size:** 7 Billion parameters.
* **Teacher Model Size:** 27 Billion parameters.
* **Input Format:** Instruction-Input-Output structure.
* **Key Metric:**
  * **BERTScore (F1):** 0.8441.
* [Benchmark Dataset](https://huggingface.co/datasets/MedSwin/MedQuAD_Benchmark)
* [Benchmark Logs](https://github.com/MedSwin/Finetuning/tree/main/benchmarks/MedQuAD_benchmark_runs)
> Review the full set of benchmark metrics via the [Benchmark Document Preview](https://hackmd.io/@ngFNmXW1RVOfNb7b3NYBJg/model_review).
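For reference, BERTScore's F1 greedily matches each candidate token embedding to its most similar reference token embedding (and vice versa) by cosine similarity. The sketch below substitutes hand-made 2-D vectors for real contextual embeddings, so the numbers are illustrative only, not a reproduction of the reported 0.8441:

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def bertscore_f1(cand_embs, ref_embs):
    """F1 over greedy-matched token embedding cosine similarities."""
    recall = sum(max(cosine(r, c) for c in cand_embs) for r in ref_embs) / len(ref_embs)
    precision = sum(max(cosine(c, r) for r in ref_embs) for c in cand_embs) / len(cand_embs)
    return 2 * precision * recall / (precision + recall)

# Toy 2-D "embeddings" standing in for contextual token vectors
reference = [[1.0, 0.0], [0.0, 1.0]]
candidate = [[0.9, 0.1], [0.1, 0.9]]
f1 = bertscore_f1(candidate, reference)
print(f"toy BERTScore F1: {f1:.4f}")
```

The real metric uses contextual embeddings from a pretrained encoder (e.g. via the `bert-score` package), typically with IDF weighting and baseline rescaling.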