# ThaiLLM-8B-SFT-IQ (Medical)
## Model Description
ThaiLLM-8B-SFT-IQ (Medical) is a Thai-language large language model specialized for the medical information query (IQ) task and citation-grounded question answering.
The model is fine-tuned from ThaiLLM-8B-SFT using supervised fine-tuning (SFT) with medical-domain data. It is designed for retrieval-augmented generation (RAG) workflows, where answers must be generated only from provided medical contexts and returned with explicit citations.
This model does not include a retriever and must be used with externally supplied medical documents.
## Intended Use
### Suitable for
- Medical RAG systems (Thai)
- Medical document question answering
- Citation-grounded medical QA
- Medical education and evaluation
### Not suitable for
- Medical diagnosis or treatment decisions
- Patient-facing clinical advice
- Emergency or critical-care use
## Model Details
- Base model: ThaiLLM-8B-SFT
- Domain: Medical
- Model type: Decoder-only causal language model
- Fine-tuning: Supervised fine-tuning (LoRA)
- Language: Thai
- Context length: 4096 tokens
- Precision: bfloat16
## Performance (Medical Information Query Setting)
| Model | Response (BLEU) | Citations (Jaccard) |
|---|---|---|
| Qwen3-8B-Base | 0.267 | 0.075 |
| ThaiLLM-8B-SFT | 0.406 | 0.133 |
| ThaiLLM-8B-SFT-IQ (Medical) | 0.436 | 0.549 |
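The citation score measures set overlap between the fact IDs the model cites and the gold citation set. The exact evaluation script is not published here, so the following is a minimal sketch of a Jaccard computation under that assumption:

```python
def citation_jaccard(predicted: set[str], gold: set[str]) -> float:
    """Jaccard similarity |P ∩ G| / |P ∪ G| between citation ID sets."""
    if not predicted and not gold:
        return 1.0  # convention assumed here: two empty sets count as a match
    return len(predicted & gold) / len(predicted | gold)

# Example: the model cites facts 1 and 2, but only fact 1 is gold.
print(citation_jaccard({"1", "2"}, {"1"}))  # 0.5
```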
## Training
- Framework: Unsloth
- Method: LoRA-based supervised fine-tuning
| Hyperparameter | Value |
|---|---|
| Learning rate | 2e-4 |
| LoRA rank | 32 |
| LoRA alpha | 32 |
| Sequence length | 4096 |
| Epochs | 3 |
| Effective batch size | 16 (2 per device × 8 gradient accumulation) |
Training data consists of Thai medical question–answer pairs with context grounding and citation supervision.
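Each record is rendered into the single `text` field that `SFTTrainer` consumes (`dataset_text_field="text"` in the script below). The exact template is not published, so the helper below is a hypothetical sketch of how a question, numbered facts, and a JSON target might be assembled:

```python
import json

def render_example(question: str, facts: list[str], answer: str, citations: list[str]) -> str:
    """Render one QA record into a single training string (hypothetical template)."""
    facts_block = "\n".join(f"[{i + 1}] {fact}" for i, fact in enumerate(facts))
    target = json.dumps({"answer": answer, "citations": citations}, ensure_ascii=False)
    return (
        "Answer in JSON format with citations only.\n"
        f"Question:\n{question}\n"
        f"Facts:\n{facts_block}\n"
        f"Answer:\n{target}"
    )
```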
### Training Script

The script assumes `base_model` (the ThaiLLM-8B-SFT checkpoint path), `seed`, `train_ds`, and `paths` are defined by the surrounding training harness.

```python
import torch
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer

# Load the base model in 4-bit for memory-efficient LoRA fine-tuning.
model_base, tokenizer = FastLanguageModel.from_pretrained(
    base_model,
    max_seq_length=4096,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    device_map="balanced",
)

# Attach LoRA adapters (rank 32) to the attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model_base,
    r=32,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=seed,
    use_rslora=False,
    loftq_config=None,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    dataset_num_proc=4,
    args=SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch size 16 per device
        gradient_checkpointing=True,
        warmup_ratio=0.1,
        warmup_steps=5,
        num_train_epochs=3,
        learning_rate=2e-4,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        seed=seed,
        report_to=["tensorboard"],
        output_dir=paths["log_path"],
        logging_dir=paths["log_path"],
    ),
)
trainer.train()
```
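After training, the adapters can be saved alone or merged back into the base weights. A minimal sketch using Unsloth's saving helpers (directory names are placeholders):

```python
# Save only the LoRA adapter weights (small artifact; needs the base model at load time).
model.save_pretrained("thaillm-8b-sft-iq-lora")
tokenizer.save_pretrained("thaillm-8b-sft-iq-lora")

# Or merge the adapters into the base model for standalone deployment.
model.save_pretrained_merged(
    "thaillm-8b-sft-iq-merged",
    tokenizer,
    save_method="merged_16bit",
)
```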
## Usage
The model is trained to emit strict JSON output and to answer only from the supplied medical contexts.
### Expected Instruction Format
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "ThaiLLM-8B-SFT-IQ"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Instructions enforcing context-only answers in strict JSON with fact-ID citations.
system_prompt = """\
Answer in JSON format with citations only.
Use only the provided medical contexts to answer the question.
Include the fact IDs that support your answer in the following format:
{"answer": "<ANSWER_TEXT>", "citations": ["<FACT_ID1>", "<FACT_ID2>"]}
If unknown, respond with:
{"answer": "unknown", "citations": []}
"""

# "Where can pain from a traumatic jaw-joint injury radiate while chewing?"
question = "การบาดเจ็บจากอุบัติเหตุที่ข้อต่อขากรรไกรสามารถทำให้อาการปวดร้าวไปที่ใดเมื่อเคี้ยวอาหาร?"

# Externally retrieved contexts, numbered so the model can cite fact IDs:
# [1] A traumatic jaw-joint injury can cause pain radiating to the ear while chewing.
# [2] Temporomandibular joint (TMJ) inflammation can cause pain near the earlobe,
#     radiating to the ear while chewing.
facts = """\
[1] การบาดเจ็บจากอุบัติเหตุที่ข้อต่อขากรรไกรสามารถทำให้ปวดร้าวไปที่หูเมื่อเคี้ยวอาหาร
[2] ข้อต่อขากรรไกรอักเสบ (TMJ) สามารถทำให้ปวดบริเวณใกล้ติ่งหูและร้าวไปที่หูเมื่อเคี้ยวอาหาร
"""

prompt = f"""{system_prompt}
Question:
{question}
Facts:
{facts}
"""

inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.inference_mode():
    outputs = model.generate(
        inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
    )

# Decode only the newly generated tokens, skipping the echoed prompt.
generated = outputs[0, inputs.shape[-1]:]
print(tokenizer.decode(generated, skip_special_tokens=True))
```
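Because the model emits strict JSON, the response can be parsed and its citations checked directly. A minimal sketch continuing the example above (the fallback for malformed output is an assumption, not part of the model's contract):

```python
import json

def parse_response(raw: str) -> tuple[str, list[str]]:
    """Parse the model's JSON reply into (answer, citations), falling back to "unknown"."""
    try:
        payload = json.loads(raw)
        return payload["answer"], list(payload.get("citations", []))
    except (json.JSONDecodeError, KeyError, TypeError):
        return "unknown", []

answer, citations = parse_response(tokenizer.decode(generated, skip_special_tokens=True))
print(answer)     # Thai answer text grounded in the facts block
print(citations)  # e.g. ["1"], referencing fact IDs from the prompt
```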