ThaiLLM-8B-SFT-IQ (Medical)

Model Description

ThaiLLM-8B-SFT-IQ (Medical) is a Thai-language large language model specialized for medical information queries and citation-grounded question answering.

The model is fine-tuned from ThaiLLM-8B-SFT using supervised fine-tuning (SFT) with medical-domain data. It is designed for retrieval-augmented generation (RAG) workflows, where answers must be generated only from provided medical contexts and returned with explicit citations.

This model does not include a retriever and must be used with externally supplied medical documents.
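
Because retrieval happens outside the model, the caller is responsible for numbering each retrieved passage so the model can cite it. A minimal sketch of that step (the retriever call is a placeholder; the `[1]`, `[2]` fact-ID format matches the Usage example below):

def build_facts_block(passages):
    """Number retrieved passages as fact IDs the model can cite."""
    return "\n".join(f"[{i}] {p}" for i, p in enumerate(passages, start=1))

# passages = retriever.search(question, top_k=5)  # external retriever (placeholder)
# facts = build_facts_block(passages)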


Intended Use

Suitable for

  • Medical RAG systems (Thai)
  • Medical document question answering
  • Citation-grounded medical QA
  • Medical education and evaluation

Not suitable for

  • Medical diagnosis or treatment decisions
  • Patient-facing clinical advice
  • Emergency or critical-care use

Model Details

  • Base model: ThaiLLM-8B-SFT
  • Domain: Medical
  • Model type: Decoder-only causal language model
  • Fine-tuning: Supervised fine-tuning (LoRA)
  • Language: Thai
  • Context length: 4096 tokens
  • Precision: bfloat16

Performance (Medical Information Query Setting)

Model                         Response (BLEU)   Citations (Jaccard)
Qwen3-8B-Base                 0.267             0.075
ThaiLLM-8B-SFT                0.406             0.133
ThaiLLM-8B-SFT-IQ (Medical)   0.4363            0.5485
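
Citation quality is scored with Jaccard similarity between the predicted and reference citation sets. A minimal sketch of that metric (assuming citations are compared as unordered sets of fact IDs):

def citation_jaccard(pred: set[str], gold: set[str]) -> float:
    """Jaccard similarity between predicted and reference citation sets."""
    if not pred and not gold:
        return 1.0  # both empty: treated as a perfect match (assumption)
    return len(pred & gold) / len(pred | gold)

# e.g. predicted {"1", "2"} vs reference {"2", "3"} -> 1/3
print(citation_jaccard({"1", "2"}, {"2", "3"}))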

Training

  • Framework: Unsloth
  • Method: LoRA-based supervised fine-tuning
Hyperparameter    Value
Learning rate     2e-4
LoRA rank         32
LoRA alpha        32
Sequence length   4096
Epochs            3
Batch size        8

Training data consists of Thai medical question–answer pairs with context grounding and citation supervision.
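
The trainer below reads a single `text` field (`dataset_text_field="text"`), so each question–answer record must first be rendered into one prompt-plus-target string. A hypothetical preprocessing sketch (the field names and chat-template usage are assumptions, not the released pipeline):

def to_text(example):
    """Render one QA record into the single `text` field consumed by SFTTrainer."""
    messages = [
        {"role": "user",
         "content": f"{example['instruction']}\n\nQuestion:\n{example['question']}\n\nFacts:\n{example['facts']}"},
        {"role": "assistant", "content": example["answer_json"]},  # strict JSON target
    ]
    return {"text": tokenizer.apply_chat_template(messages, tokenize=False)}

train_ds = raw_ds.map(to_text)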

Training Script

import torch
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer

# Load the base model in 4-bit for memory-efficient LoRA fine-tuning
model_base, tokenizer = FastLanguageModel.from_pretrained(
    base_model,
    max_seq_length=4096,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
    device_map="balanced",
)

# Attach LoRA adapters (rank 32) to the attention and MLP projections
model = FastLanguageModel.get_peft_model(
    model_base,
    r=32,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_gradient_checkpointing="unsloth",
    random_state=seed,
    use_rslora=False,
    loftq_config=None,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    dataset_num_proc=4,
    args=SFTConfig(
        dataset_text_field="text",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        gradient_checkpointing=True,
        warmup_ratio=0.1,
        warmup_steps=5,
        num_train_epochs=3,
        learning_rate=2e-4,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        fp16=not torch.cuda.is_bf16_supported(),  # fall back to fp16 on GPUs without bf16
        bf16=torch.cuda.is_bf16_supported(),
        seed=args.seed,
        report_to=["tensorboard"],
        output_dir=f"{paths['log_path']}",
        logging_dir=f"{paths['log_path']}",
    ),
)

trainer.train()
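
After training, the LoRA adapters can be kept separate or merged into the base weights. A sketch using Unsloth's saving helpers (the output paths are placeholders, not the released checkpoints):

# Save LoRA adapters only (small; requires the base model at load time)
model.save_pretrained("outputs/lora_adapter")
tokenizer.save_pretrained("outputs/lora_adapter")

# Or merge the adapters into the base model for standalone bf16 inference
model.save_pretrained_merged("outputs/merged_16bit", tokenizer, save_method="merged_16bit")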
    

Usage

The model is optimized for strict JSON output and context-only medical answers.

Expected Instruction Format

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "ThaiLLM/ThaiLLM-8B-SFT-IQ"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

system_prompt = """\
Answer in JSON format with citations only.
Use only the provided medical contexts to answer the question.
Include the fact IDs that support your answer in the following format:
{"answer": "<ANSWER_TEXT>", "citations": ["<FACT_ID1>", "<FACT_ID2>"]}

If unknown, respond with:
{"answer": "unknown", "citations": []}
"""

# Question: "After a traumatic injury to the jaw joint, where can pain radiate when chewing?"
question = "การบาดเจ็บจากอุบัติเหตุที่ข้อต่อขากรรไกรสามารถทำให้อาการปวดร้าวไปที่ใดเมื่อเคี้ยวอาหาร?"

# Facts: [1] Traumatic injury to the jaw joint can cause pain radiating to the ear when chewing
#        [2] Temporomandibular joint (TMJ) inflammation can cause pain near the earlobe, radiating to the ear when chewing
facts = """\
[1] การบาดเจ็บจากอุบัติเหตุที่ข้อต่อขากรรไกรสามารถทำให้ปวดร้าวไปที่หูเมื่อเคี้ยวอาหาร
[2] ข้อต่อขากรรไกรอักเสบ (TMJ) สามารถทำให้ปวดบริเวณใกล้ติ่งหูและร้าวไปที่หูเมื่อเคี้ยวอาหาร
"""

prompt = f"""{system_prompt}

Question:
{question}

Facts:
{facts}
"""

inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

with torch.inference_mode():
    outputs = model.generate(
        inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding; temperature is ignored when sampling is off
    )

generated = outputs[0, inputs.shape[-1]:]
print(tokenizer.decode(generated, skip_special_tokens=True))
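
Since the model is trained to emit strict JSON, the completion can be parsed directly. A defensive sketch (the fallback policy mirroring the system prompt is an assumption):

import json

def parse_answer(raw: str) -> dict:
    """Parse the model's JSON completion; fall back to 'unknown' on malformed output."""
    try:
        parsed = json.loads(raw)
        if isinstance(parsed.get("answer"), str) and isinstance(parsed.get("citations"), list):
            return parsed
    except json.JSONDecodeError:
        pass
    return {"answer": "unknown", "citations": []}  # fallback format from the system prompt

result = parse_answer(tokenizer.decode(generated, skip_special_tokens=True))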