Percy3822 committed on
Commit 078d71d · verified · 1 Parent(s): 55581b0

Update train.py

Files changed (1)
  1. train.py +27 -50
train.py CHANGED
@@ -1,73 +1,50 @@
+import argparse
 from datasets import load_dataset
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
-import os
-import sys

-print("🔥 Python AI training script started!", file=sys.stderr)
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset", required=True)
+parser.add_argument("--output", default="trained_model")
+args = parser.parse_args()

-DATASET_PATH = "python_ai_dataset.jsonl"
-MODEL_ID = "bigcode/starcoderbase-7b"
-OUTPUT_DIR = "train_output"
+print("📊 Loading dataset...")
+dataset = load_dataset("json", data_files=args.dataset, split="train")

-# === Step 1: Check dataset ===
-if not os.path.exists(DATASET_PATH):
-    print(f"❌ Dataset file not found: {DATASET_PATH}", file=sys.stderr)
-    sys.exit(1)
+print("🧠 Loading model and tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+tokenizer.pad_token = tokenizer.eos_token
+model = AutoModelForCausalLM.from_pretrained("distilgpt2")

-# === Step 2: Load dataset (first 10 samples for fast test) ===
-try:
-    dataset = load_dataset("json", data_files=DATASET_PATH, split="train[:10]")  # Load only 10 samples for testing
-except Exception as e:
-    print(f"❌ Failed to load dataset: {e}", file=sys.stderr)
-    sys.exit(1)
+# ✅ Clean, batch-safe tokenize
+def tokenize(batch):
+    full_texts = [str(p) + str(c) for p, c in zip(batch["prompt"], batch["completion"])]
+    return tokenizer(full_texts, padding="max_length", truncation=True, max_length=256)

-# === Step 3: Load tokenizer and model ===
-try:
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
-except Exception as e:
-    print(f"❌ Failed to load model/tokenizer: {e}", file=sys.stderr)
-    sys.exit(1)
-
-# === Step 4: Preprocess data ===
-def tokenize(example):
-    return tokenizer(example["prompt"] + "\n" + example["completion"], truncation=True, max_length=512)
-
-try:
-    tokenized_dataset = dataset.map(tokenize, remove_columns=["prompt", "completion"])
-except Exception as e:
-    print(f"❌ Tokenization error: {e}", file=sys.stderr)
-    sys.exit(1)
+print("🔁 Tokenizing...")
+tokenized = dataset.map(tokenize, batched=True)

+print("📦 Setting up trainer...")
 data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

-# === Step 5: Training config ===
 training_args = TrainingArguments(
-    output_dir=OUTPUT_DIR,
-    overwrite_output_dir=True,
-    per_device_train_batch_size=1,
+    output_dir=args.output,
+    per_device_train_batch_size=2,
     num_train_epochs=1,
-    logging_dir="./logs",
     logging_steps=1,
-    save_strategy="epoch",
+    save_steps=5,
     save_total_limit=1,
-    fp16=False,
-    report_to="none"
+    report_to=[]
 )

-# === Step 6: Train the model ===
 trainer = Trainer(
     model=model,
     args=training_args,
-    train_dataset=tokenized_dataset,
+    train_dataset=tokenized,
     tokenizer=tokenizer,
-    data_collator=data_collator
+    data_collator=data_collator,
 )

-print("🚀 Starting training on 10 samples...", file=sys.stderr)
 trainer.train()
-
-# === Step 7: Save model ===
-trainer.save_model(OUTPUT_DIR)
-tokenizer.save_pretrained(OUTPUT_DIR)
-print("✅ Training finished and model saved!", file=sys.stderr)
+trainer.save_model(args.output)
+print("✅ Done.")