dejanseo commited on
Commit
3deb901
·
verified ·
1 Parent(s): b040c82

Upload 2 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ assets/train-loss.png filter=lfs diff=lfs merge=lfs -text
assets/train-270m-ft.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Full fine-tune gemma-3-270m to reconstruct prompts from model outputs."""
3
+
4
+ from datasets import load_from_disk
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForSeq2Seq
6
+
7
+ # --- Configuration ---
8
+ MODEL_NAME = "google/gemma-3-270m"
9
+ DATASET_PATH = "tokenized-dataset-plain"
10
+ OUTPUT_DIR = "checkpoints-270m-ft"
11
+ MAX_SEQ_LENGTH = 2048
12
+
13
+ # --- Training ---
14
+ BATCH_SIZE = 2
15
+ GRAD_ACCUM = 8
16
+ LEARNING_RATE = 5e-5
17
+ NUM_EPOCHS = 1
18
+ WARMUP_STEPS = 100
19
+ LOGGING_STEPS = 1
20
+ SAVE_STEPS = 100
21
+
22
+
23
+ def main():
24
+ print("Loading pre-tokenized dataset...")
25
+ dataset = load_from_disk(DATASET_PATH)
26
+ print(f"Training examples: {len(dataset)}")
27
+
28
+ print(f"Loading model: {MODEL_NAME}")
29
+ model = AutoModelForCausalLM.from_pretrained(
30
+ MODEL_NAME,
31
+ torch_dtype="bfloat16",
32
+ device_map="auto",
33
+ )
34
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
35
+
36
+ trainer = Trainer(
37
+ model=model,
38
+ processing_class=tokenizer,
39
+ train_dataset=dataset,
40
+ data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True),
41
+ args=TrainingArguments(
42
+ output_dir=OUTPUT_DIR,
43
+ per_device_train_batch_size=BATCH_SIZE,
44
+ gradient_accumulation_steps=GRAD_ACCUM,
45
+ learning_rate=LEARNING_RATE,
46
+ num_train_epochs=NUM_EPOCHS,
47
+ warmup_steps=WARMUP_STEPS,
48
+ logging_steps=LOGGING_STEPS,
49
+ save_steps=SAVE_STEPS,
50
+ bf16=True,
51
+ seed=42,
52
+ report_to="wandb",
53
+ logging_strategy="steps",
54
+ gradient_checkpointing=True,
55
+ ),
56
+ )
57
+
58
+ print("Training...")
59
+ trainer.train()
60
+
61
+ print("Saving model...")
62
+ trainer.save_model(f"{OUTPUT_DIR}/final")
63
+ tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
64
+ print("Done.")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
assets/train-loss.png ADDED

Git LFS Details

  • SHA256: 22e27d2504371f7a5c14d1cf816fe3a84e2fb548308a517496d36db303f44165
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB