# ThaiLLM-30B-IQ (Medical)

## Model Description

ThaiLLM-30B-IQ (Medical) is a Thai-language large language model specialized for medical information queries and citation-grounded question answering.

The model is fine-tuned from ThaiLLM-30B via LoRA-based supervised fine-tuning (SFT) on medical-domain data. It is designed for retrieval-augmented generation (RAG) workflows in which answers must be generated only from the provided medical contexts and returned with explicit citations.

This model does not include a retriever and must be paired with externally supplied medical documents.
## Intended Use

### Suitable for
- Medical RAG systems (Thai)
- Medical document question answering
- Citation-grounded medical QA
- Medical education and evaluation
### Not suitable for
- Medical diagnosis or treatment decisions
- Patient-facing clinical advice
- Emergency or critical-care use
## Model Details
- Base model: ThaiLLM-30B
- Domain: Medical
- Model type: Decoder-only causal language model
- Fine-tuning: Supervised fine-tuning (LoRA)
- Language: Thai
- Precision: bfloat16
## Performance (Medical Information Query Setting)
| Model | Response (BLEU) | Citations (Jaccard) |
|---|---|---|
| ThaiLLM-30B (base) | 0.4060 | 0.1786 |
| ThaiLLM-30B-SFT-IQ (Medical) | 0.4331 | 0.5458 |
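The citation score is the Jaccard similarity between the predicted and reference citation-ID sets. A minimal sketch of that metric (the exact evaluation script is not published here, so this is an assumption about how the score is computed):

```python
def citation_jaccard(predicted, gold):
    """Jaccard similarity between predicted and gold citation-ID sets."""
    p, g = set(predicted), set(gold)
    if not p and not g:
        return 1.0  # both empty: treat as perfect agreement
    return len(p & g) / len(p | g)

# Example: one of two predicted IDs matches the single gold ID.
print(citation_jaccard(["1", "3"], ["1"]))  # |{1}| / |{1, 3}| = 0.5
```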
## Training
- Framework: LLaMA-Factory
- Method: LoRA-based supervised fine-tuning
| Hyperparameter | Value |
|---|---|
| Learning rate | 2e-4 |
| LoRA rank | 16 |
| LoRA alpha | 16 |
| Sequence length | 4096 |
| Epochs | 3 |
| Batch size (effective) | 4 (per-device 1 × grad. accumulation 4) |
Training data consists of Thai medical question–answer pairs with context grounding and citation supervision.
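The batch size of 4 above is the effective batch size: the launch script runs a per-device batch of 1 with 4 gradient-accumulation steps. A small sketch making that relationship, and the LoRA scaling factor, explicit:

```python
# Hyperparameters copied from the table above.
lora_cfg = {
    "lora_rank": 16,
    "lora_alpha": 16,
    "lora_dropout": 0.0,
    "sequence_length": 4096,
    "epochs": 3,
}

# LoRA updates are scaled by alpha / rank; here the scaling is 1.0.
scaling = lora_cfg["lora_alpha"] / lora_cfg["lora_rank"]

# The table's batch size of 4 is the effective batch:
# per-device batch 1 x 4 gradient-accumulation steps.
per_device_batch = 1
grad_accum_steps = 4
effective_batch = per_device_batch * grad_accum_steps

print(scaling, effective_batch)  # 1.0 4
```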
### Training Script

```bash
llamafactory-cli train \
    --stage sft \
    --do_train True \
    --model_name_or_path ${MODEL_PATH} \
    --preprocessing_num_workers 16 \
    --deepspeed ./examples/deepspeed/ds_z2_config.json \
    --finetuning_type lora \
    --template qwen3 \
    --flash_attn auto \
    --dataset thaillm-SFT \
    --cutoff_len 4096 \
    --learning_rate 2e-4 \
    --num_train_epochs 3.0 \
    --max_samples 100000 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --gradient_checkpointing True \
    --gradient_checkpointing_kwargs '{"use_reentrant":false}' \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 5 \
    --save_steps 100 \
    --warmup_steps 0 \
    --packing False \
    --enable_thinking True \
    --report_to tensorboard \
    --quantization_bit 4 \
    --output_dir "./checkpoints/ThaiLLM-30B-SFT-IQ" \
    --overwrite_output_dir True \
    --bf16 True \
    --plot_loss True \
    --trust_remote_code True \
    --include_num_input_tokens_seen True \
    --optim adamw_8bit \
    --lora_rank 16 \
    --lora_alpha 16 \
    --upcast_layernorm True \
    --ddp_timeout 180000000 \
    --lora_dropout 0 \
    --lora_target all \
    --freeze_vision_tower True \
    --freeze_multi_modal_projector True \
    --use_unsloth False \
    --seed 1234
```
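LLaMA-Factory resolves `--dataset thaillm-SFT` through its dataset registry; the exact schema of that dataset is not published here. A hypothetical record in the common Alpaca-style format (the field names and placeholder content are assumptions) would pair the instruction and grounded facts with a strict-JSON target:

```python
import json

# Hypothetical thaillm-SFT record in LLaMA-Factory's Alpaca-style schema;
# field names and placeholder content below are illustrative assumptions.
record = {
    "instruction": "Answer in JSON format with citations only. "
                   "Use only the provided medical contexts to answer the question.",
    "input": "Question:\n<QUESTION_TEXT>\nFacts:\n[1] <FACT_TEXT>\n[2] <FACT_TEXT>",
    # The supervised target is itself strict JSON.
    "output": json.dumps({"answer": "<ANSWER_TEXT>", "citations": ["1"]}),
}

# The target must round-trip as JSON for citation supervision to be checkable.
parsed = json.loads(record["output"])
print(parsed["citations"])  # ['1']
```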
## Usage

The model is optimized for strict JSON output and context-only medical answers.

### Expected Instruction Format
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "ThaiLLM-30B-SFT-IQ"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

system_prompt = """\
Answer in JSON format with citations only.
Use only the provided medical contexts to answer the question.
Include the fact IDs that support your answer in the following format:
{"answer": "<ANSWER_TEXT>", "citations": ["<FACT_ID1>", "<FACT_ID2>"]}
If unknown, respond with:
{"answer": "unknown", "citations": []}
"""

# Thai: "Where can pain from a traumatic jaw-joint injury radiate to when chewing?"
question = "การบาดเจ็บจากอุบัติเหตุที่ข้อต่อขากรรไกรสามารถทำให้อาการปวดร้าวไปที่ใดเมื่อเคี้ยวอาหาร?"

# Thai facts:
# [1] "A traumatic jaw-joint injury can cause pain radiating to the ear when chewing."
# [2] "Temporomandibular joint (TMJ) inflammation can cause pain near the earlobe,
#      radiating to the ear, when chewing."
facts = """\
[1] การบาดเจ็บจากอุบัติเหตุที่ข้อต่อขากรรไกรสามารถทำให้ปวดร้าวไปที่หูเมื่อเคี้ยวอาหาร
[2] ข้อต่อขากรรไกรอักเสบ (TMJ) สามารถทำให้ปวดบริเวณใกล้ติ่งหูและร้าวไปที่หูเมื่อเคี้ยวอาหาร
"""

prompt = f"""{system_prompt}
Question:
{question}
Facts:
{facts}
"""

# Build the chat-formatted input tensor.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# Greedy decoding keeps the strict JSON output deterministic.
with torch.inference_mode():
    outputs = model.generate(
        inputs,
        max_new_tokens=256,
        do_sample=False,
    )

# Decode only the newly generated tokens, not the prompt.
generated = outputs[0, inputs.shape[-1]:]
print(tokenizer.decode(generated, skip_special_tokens=True))
```
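Because the model is trained to emit strict JSON, the decoded text can be validated before downstream use. A minimal sketch of such a check (the helper name is an assumption; the schema follows the system prompt above):

```python
import json

def parse_model_answer(text):
    """Parse the model's strict-JSON response; return None if malformed."""
    try:
        data = json.loads(text.strip())
    except json.JSONDecodeError:
        return None
    # Schema from the system prompt: an answer string plus a list of fact IDs.
    if not isinstance(data.get("answer"), str):
        return None
    if not isinstance(data.get("citations"), list):
        return None
    return data

# A well-formed "unknown" response passes; free text does not.
print(parse_model_answer('{"answer": "unknown", "citations": []}'))
print(parse_model_answer("I am not sure."))  # None
```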