chrisvoncsefalvay committed on
Commit
a44915a
·
verified ·
1 Parent(s): 4bbe78f

Upload train_smol_discharge.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_smol_discharge.py +93 -0
train_smol_discharge.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
#     "bitsandbytes",
# ]
# ///
"""LoRA supervised fine-tuning of SmolLM3-3B-Base on synthetic discharge notes.

Standalone PEP 723 script (runnable e.g. with ``uv run``): loads the base
model and tokenizer, installs a ChatML chat template, fine-tunes with TRL's
``SFTTrainer`` using a LoRA adapter, and pushes checkpoints plus the final
model to the Hugging Face Hub.
"""

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# Hub identifiers used in several places below.
BASE_MODEL = "HuggingFaceTB/SmolLM3-3B-Base"
DATASET_ID = "chrisvoncsefalvay/smol-discharge-notes-sft"
HUB_MODEL_ID = "chrisvoncsefalvay/smollm3-discharge-notes-sft"

# ChatML-style template: system/user/assistant turns wrapped in
# <|im_start|>/<|im_end|> markers, with an optional generation prompt.
CHAT_TEMPLATE = "{% for message in messages %}{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"


def _build_tokenizer():
    """Load the base tokenizer and install the ChatML template and markers."""
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    tokenizer.chat_template = CHAT_TEMPLATE
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
    )
    if tokenizer.pad_token is None:
        # Base tokenizer ships without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _build_model(tokenizer):
    """Load the base model and grow its embeddings for the added tokens."""
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype="auto",
        device_map="auto",
    )
    # Embedding matrix must match the tokenizer after add_special_tokens.
    model.resize_token_embeddings(len(tokenizer))
    return model


def main() -> None:
    """Run the full SFT pipeline and push the result to the Hub."""
    tokenizer = _build_tokenizer()
    model = _build_model(tokenizer)

    print("Loading dataset...")
    train_dataset = load_dataset(DATASET_ID, split="train")
    eval_dataset = load_dataset(DATASET_ID, split="validation")
    print(f"Train: {len(train_dataset)} examples")
    print(f"Eval: {len(eval_dataset)} examples")

    config = SFTConfig(
        output_dir="smollm3-discharge-notes-sft",
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        hub_strategy="every_save",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,  # effective train batch size of 16
        learning_rate=2e-5,
        max_length=2048,
        logging_steps=10,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        eval_strategy="steps",
        eval_steps=50,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,  # trade compute for memory headroom
        bf16=True,
        report_to="trackio",
        project="clinical-action-processing",
        run_name="smollm3-3b-discharge-sft-a100",
    )

    # LoRA over all attention and MLP projections; base weights stay frozen.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )

    print("Initializing trainer...")
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
    )

    print("Starting training...")
    trainer.train()

    print("Pushing to Hub...")
    trainer.push_to_hub()

    print("Complete!")


if __name__ == "__main__":
    # Guard so importing this module does not trigger downloads/training.
    main()