akiliaiafrica committed on
Commit
bee2af9
·
verified ·
1 Parent(s): 41d44e7

Upload train_jafari_chatbot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_jafari_chatbot.py +120 -0
train_jafari_chatbot.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers>=4.51.0", "datasets", "accelerate", "bitsandbytes"]
# ///
"""LoRA supervised fine-tuning of Qwen3-0.6B on the Jafari Credit WhatsApp dataset.

The script loads the conversation dataset from the Hugging Face Hub, holds out
10% of it for evaluation, renders each conversation to plain text with the
model's chat template, trains LoRA adapters with TRL's ``SFTTrainer`` and
pushes the result to a private Hub repository.

Any exception raised along the way is reported with a traceback and the
process exits with status 1.
"""

import sys
import traceback

# Identifiers that are needed in more than one place; keeping them here stops
# the tokenizer/model pair and the Hub repo/success message from drifting apart.
MODEL_ID = "Qwen/Qwen3-0.6B"
DATASET_ID = "akiliaiafrica/jafari-credit-whatsapp-chatbot"
HUB_MODEL_ID = "akiliaiafrica/jafari-chatbot-qwen3-0.6b"
OUTPUT_DIR = "jafari-chatbot-qwen3"
RULE = "=" * 80


def sft_config_kwargs() -> dict:
    """Return the keyword arguments used to build the ``SFTConfig``.

    A fresh dict is returned on every call, so callers may modify it freely.
    """
    return {
        "output_dir": OUTPUT_DIR,
        "num_train_epochs": 3,
        # 2 sequences per step x 8 accumulation steps = 16 sequences per
        # optimizer update on each device.
        "per_device_train_batch_size": 2,
        "per_device_eval_batch_size": 2,
        "gradient_accumulation_steps": 8,
        "gradient_checkpointing": True,
        "learning_rate": 2e-4,
        "lr_scheduler_type": "cosine",
        "warmup_ratio": 0.1,
        # Evaluation and checkpointing share the same 50-step cadence.
        "eval_strategy": "steps",
        "eval_steps": 50,
        "save_strategy": "steps",
        "save_steps": 50,
        "save_total_limit": 3,
        "push_to_hub": True,
        "hub_model_id": HUB_MODEL_ID,
        "hub_private_repo": True,
        "hub_strategy": "every_save",
        "logging_steps": 5,
        "report_to": "trackio",
        "run_name": "jafari-credit-whatsapp-sft",
        "bf16": True,
        "optim": "adamw_8bit",
        "max_grad_norm": 1.0,
    }


def main() -> None:
    """Run the full pipeline: load data, preprocess, train and push to the Hub.

    Exits the process with status 1 (after printing the error and traceback)
    if any step raises.
    """
    print(RULE)
    print("PRODUCTION TRAINING - Jafari Credit WhatsApp Chatbot")
    print(RULE)
    print(f"Python version: {sys.version}")

    try:
        # Third-party imports live inside the try block so that a missing or
        # broken dependency is reported through the same error path below.
        print("\n[1/7] Importing libraries...")
        from datasets import load_dataset
        from peft import LoraConfig
        from trl import SFTTrainer, SFTConfig
        from transformers import AutoTokenizer
        # Not referenced by name below; the trainer is only told
        # report_to="trackio". Importing it here surfaces a missing install
        # at step 1.
        import trackio  # noqa: F401
        print("✓ All imports successful")

        # Load the dataset from Hub
        print("\n[2/7] Loading dataset...")
        dataset = load_dataset(DATASET_ID, split="train")
        print(f"✓ Dataset loaded: {len(dataset)} conversations")
        print(f"Sample keys: {list(dataset[0].keys())}")

        # Create train/eval split (fixed seed keeps the 90/10 split reproducible)
        print("\n[3/7] Creating train/eval split...")
        dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
        print(f"✓ Train: {len(dataset_split['train'])} | Eval: {len(dataset_split['test'])}")

        # Load tokenizer
        print("\n[4/7] Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("✓ Tokenizer loaded")

        # Preprocess dataset
        print("\n[5/7] Preprocessing dataset with chat template...")

        def convert_to_text(example):
            """Render one example's ``messages`` list to a single ``text`` string."""
            text = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
                # Rendering complete conversations, so no trailing
                # generation prompt is appended.
                add_generation_prompt=False,
            )
            return {"text": text}

        # Drop every original column so only the rendered "text" remains.
        train_dataset = dataset_split["train"].map(
            convert_to_text,
            remove_columns=dataset_split["train"].column_names,
        )
        eval_dataset = dataset_split["test"].map(
            convert_to_text,
            remove_columns=dataset_split["test"].column_names,
        )
        print(f"✓ Train={len(train_dataset)}, Eval={len(eval_dataset)}")
        print(f"Sample text length: {len(train_dataset[0]['text'])} chars")

        # LoRA config
        print("\n[6/7] Configuring LoRA and trainer...")
        peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
        )

        trainer = SFTTrainer(
            model=MODEL_ID,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            peft_config=peft_config,
            # Hand over the tokenizer configured above so the pad-token
            # fallback and the chat template used for preprocessing are the
            # ones the trainer tokenizes with, instead of a second copy the
            # trainer would otherwise load from the model id.
            processing_class=tokenizer,
            args=SFTConfig(**sft_config_kwargs()),
        )
        print("✓ Trainer initialized")

        # Train
        print("\n[7/7] Starting training...")
        print(RULE)
        # Push the progress banner out before training output starts.
        sys.stdout.flush()

        trainer.train()

        print("\n" + RULE)
        print("Training completed! Pushing model to Hub...")
        trainer.push_to_hub()

        print(f"\n✓ SUCCESS: Model pushed to {HUB_MODEL_ID}")
        print(RULE)

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"\n\n{RULE}")
        print(f"ERROR: {type(e).__name__}: {e}")
        print(RULE)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()