|
|
--- |
|
|
license: apache-2.0 |
|
|
datasets: |
|
|
- qiaojin/PubMedQA |
|
|
- MedAI-COS30018/PubMedQA-map |
|
|
- MedAI-COS30018/PubmedQA-u |
|
|
- MedAI-COS30018/PubMedQA-l |
|
|
- MedAI-COS30018/HealthCareMagic |
|
|
- MedAI-COS30018/iCliniq |
|
|
language: |
|
|
- en |
|
|
base_model: |
|
|
- medalpaca/medalpaca-7b |
|
|
- google/medgemma-27b-it |
|
|
pipeline_tag: question-answering |
|
|
metrics: |
|
|
- bertscore: 0.8441 |
|
|
tags: |
|
|
- medical |
|
|
- knowledge-distillation |
|
|
--- |
|
|
|
|
|
# Model Card |
|
|
|
|
|
## Model Description |
|
|
|
|
|
**MedSwin-7B-KD** is a high-performance 7B parameter language model for medical question-answering and clinical reasoning. It was created by applying a novel **Dual-Phase Knowledge Distillation (KD)** pipeline to the `medalpaca/medalpaca-7b` base model. Unlike its SFT predecessor, this model leverages the superior knowledge and reasoning capabilities of the larger `google/medgemma-27b-it` model as a "teacher" to guide the training of the smaller, more efficient "student" model. This results in a compact model that captures the clinical acumen of a much larger counterpart. |
|
|
|
|
|
- **Developed by:** Medical AI Team, Swinburne University of Technology
|
|
- **Funded by:** [Swinburne University of Technology](https://www.swinburne.edu.au) |
|
|
- **Base Model (Student):** [medalpaca/medalpaca-7b](https://huggingface.co/medalpaca/medalpaca-7b) |
|
|
- **Teacher Model:** [google/medgemma-27b-it](https://huggingface.co/google/medgemma-27b-it) |
|
|
- **Language(s):** English |
|
|
- **License:** Apache 2.0 |
|
|
|
|
|
### Intended Use |
|
|
|
|
|
This model is intended for research purposes in the following domains: |
|
|
* AI-assisted medicine and clinical decision support research. |
|
|
* Biomedical natural language processing (NLP). |
|
|
* Exploration of efficient knowledge distillation and model compression in specialized domains. |
|
|
* Generating high-quality, clinically grounded synthetic data.
|
|
|
|
|
## Training Data |
|
|
|
|
|
The model was trained on the same curated and augmented collection of medical QA datasets as the SFT version, but the *target outputs* were generated by the teacher model. |
|
|
- **PubMedQA**: Original and processed (map, u, l) variants for factoid and research-oriented questions. |
|
|
- **HealthCareMagic** & **iCliniq**: Real-world patient-doctor interactions from online portals. |
|
|
|
|
|
### Data Curation & Knowledge Distillation Pipeline |
|
|
|
|
|
The training pipeline was fundamentally redesigned to center on knowledge distillation, moving beyond simple paraphrasing to focus on transferring deep reasoning patterns. |
|
|
|
|
|
| Stage | Purpose | Methodology & Quality Control |
| :--- | :--- | :--- |
| **A. Augmented Query Generation** | Create a diverse set of high-quality input prompts. | Reuses the multi-model paraphrasing, back-translation, and style-standardization pipeline from the SFT model to generate a rich variety of instructions and inputs. |
| **B. Teacher Forcing & Output Generation** | Generate "gold-standard" responses with the superior teacher model. | **Teacher Model:** `google/medgemma-27b-it`. <br/> **Generation Strategy:** Low-temperature sampling with contrastive decoding to produce confident, factually dense, and well-structured answers. <br/> **Input:** The entire augmented set of `(Instruction, Input)` pairs from Stage A. |
| **C. Response Filtering & Alignment** | Ensure the teacher's outputs are of the highest quality for student training. | **Factual Consistency Check:** Cross-referencing key medical claims against the original context. <br/> **Style Alignment:** Enforcing a neutral, professional clinical tone. <br/> **Complexity Pruning:** Removing outputs that are overly verbose or rely on reasoning chains too complex for the student model to learn effectively. |
| **D. Dual-Phase Knowledge Distillation** | Transfer knowledge from teacher to student. | **Phase 1 (Response Mimicking):** The student model is trained to reproduce the teacher's filtered outputs directly, learning its style and factual presentation. <br/> **Phase 2 (Logit Matching):** The student is trained to align its output probability distributions (logits) with the teacher's for the same input, capturing the teacher's "thinking process" and confidence calibration. A minimal loss sketch follows this table. |
| **E. Quality Assurance** | Ensure the final training pairs are optimal for distillation. | **E1. Data Cleaning:** PHI removal; MD5-based deduplication. <br/> **E2. KD-Specific Validation:** Checking that response depth matches query complexity; ensuring reasoning patterns are learnable by the student model. |
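For illustration, the two objectives in Stage D correspond to standard distillation losses. The sketch below is a minimal, hypothetical PyTorch version, not the project's training code: it assumes the student and teacher logits have already been aligned to a shared vocabulary (MedAlpaca and MedGemma use different tokenizers, so this alignment is itself a non-trivial step), and the temperature and masking values are arbitrary examples.

```python
# Minimal sketch of the Stage D objectives (assumptions: shared vocabulary,
# illustrative temperature and masking -- not the actual training configuration).
import torch.nn.functional as F

def response_mimicking_loss(student_logits, teacher_token_ids):
    """Phase 1: cross-entropy against the teacher's generated tokens."""
    # student_logits: (batch, seq_len, vocab); teacher_token_ids: (batch, seq_len)
    return F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        teacher_token_ids.view(-1),
        ignore_index=-100,  # mask prompt and padding positions
    )

def logit_matching_loss(student_logits, teacher_logits, temperature=2.0):
    """Phase 2: KL divergence between temperature-softened distributions."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2
```

In a sequential schedule, Phase 1 would train only on the first loss over the teacher's filtered outputs, and Phase 2 would then continue training with the logit-matching term (optionally mixed with the cross-entropy loss).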
|
|
|
|
|
## Output Format |
|
|
|
|
|
All training data was formatted into the same standardized SFT structure, with the output field now generated by the teacher model:
|
|
|
|
|
```
### Instruction:
{Task descriptor and/or user question with context}

### Input:
{Additional user question or context, if any}

### Output:
{The teacher model's (MedGemma-27B) target response}
```
|
|
|
|
|
Each data point includes metadata tags for its augmentation source and a `distilled_from: medgemma-27b` tag. |
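As a small illustration, a record in this layout can be rendered into a single training string with a helper like the one below. The field names (`instruction`, `input`, `output`) and the metadata layout are assumptions inferred from this card, not the published dataset schema.

```python
# Hypothetical helper that renders one distilled record into the SFT template above.
# Field and metadata names are assumptions based on this card, not the dataset schema.
def format_record(record: dict) -> str:
    prompt = f"### Instruction:\n{record['instruction']}\n\n"
    if record.get("input"):
        prompt += f"### Input:\n{record['input']}\n\n"
    prompt += f"### Output:\n{record['output']}"
    return prompt

example = {
    "instruction": "Summarize the key finding of the abstract.",
    "input": "Context: ...",
    "output": "The study suggests ...",
    "metadata": {"augmentation_source": "back-translation", "distilled_from": "medgemma-27b"},
}
print(format_record(example))
```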
|
|
|
|
|
## Usage |
|
|
|
|
|
You can load and use the model with the Hugging Face `transformers` library in exactly the same way as the SFT version, with potentially improved output quality.
|
|
|
|
|
```python
import transformers

model_id = "MedAI-COS30018/MedSwin-7B-KD"
pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    device_map="auto",  # Use GPU if available
)

# Format your input according to the training template
instruction = "Based on the provided context, what is the most likely diagnosis?"
context = "A 45-year-old male presents with acute, crushing substernal chest pain radiating to the left arm, associated with diaphoresis and nausea for the past hour."
formatted_prompt = f"### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n### Output:\n"

# Generate a response
sequences = pipeline(
    formatted_prompt,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.3,
    top_p=0.9,
    eos_token_id=pipeline.tokenizer.eos_token_id,
)
print(sequences[0]['generated_text'])
```
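By default the pipeline echoes the prompt in `generated_text`, so the model's answer can be isolated by splitting on the final `### Output:` marker:

```python
# The returned text includes the prompt; keep only the part after the Output header.
answer = sequences[0]["generated_text"].split("### Output:")[-1].strip()
print(answer)
```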
|
|
|
|
|
## Bias, Risks, and Limitations |
|
|
|
|
|
The model inherits and may amplify biases present in its base model, teacher model, and training data. These can include: |
|
|
* **Demographic Biases:** Biases related to race, gender, age, or socioeconomic status based on patterns in the source data. |
|
|
* **Clinical Biases:** Potential over-representation of certain conditions, treatments, or clinical perspectives. |
|
|
* **Factual Accuracy:** While the teacher model is highly capable, it is not infallible. The distilled model may propagate or even amplify any errors made by the teacher. It is not a certified medical knowledge base and can generate incorrect or outdated information. |
|
|
* **Safe Deployment:** Use a **Human-in-the-Loop** (HITL) system for any real-world application. Outputs **must** be verified by a qualified healthcare professional. **Do not use for direct patient care without rigorous clinical validation.** |
|
|
|
|
|
## Technical Specifications & Evaluation |
|
|
|
|
|
* **Model Architecture:** Based on LLaMA, fine-tuned via Dual-Phase Knowledge Distillation. |
|
|
* **Model Size:** 7 Billion parameters. |
|
|
* **Teacher Model Size:** 27 Billion parameters. |
|
|
* **Input Format:** Instruction-Input-Output structure. |
|
|
* **Key Metric:** |
|
|
* **BERTScore (F1):** 0.8441.
|
|
|
|
|
* [Benchmark Dataset](https://huggingface.co/datasets/MedSwin/MedQuAD_Benchmark) |
|
|
* [Benchmark Logs](https://github.com/MedSwin/Finetuning/tree/main/benchmarks/MedQuAD_benchmark_runs) |
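As a rough guide to how the BERTScore figure can be reproduced on the benchmark dataset, the sketch below uses the Hugging Face `evaluate` library; the prediction and reference lists are placeholders, and the exact evaluation settings (scorer model, batching) may differ from the logged runs.

```python
# Hypothetical BERTScore evaluation sketch; predictions/references are placeholders.
import evaluate

bertscore = evaluate.load("bertscore")
predictions = ["Aspirin irreversibly inhibits platelet cyclooxygenase-1 ..."]   # model outputs
references = ["Aspirin inhibits COX-1, reducing thromboxane A2 synthesis ..."]  # gold answers

results = bertscore.compute(predictions=predictions, references=references, lang="en")
mean_f1 = sum(results["f1"]) / len(results["f1"])
print(f"BERTScore F1: {mean_f1:.4f}")
```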
|
|
|
|
> Review the full benchmark results for all model metrics in the [Benchmark Document Preview](https://hackmd.io/@ngFNmXW1RVOfNb7b3NYBJg/model_review).