S-Dreamer committed
Commit 09d1245 · verified · 1 Parent(s): edaa68a

Update src/train.py

Files changed (1)
  1. src/train.py +43 -15
src/train.py CHANGED
@@ -1,5 +1,6 @@
+ # src/train.py
  import os
- from typing import Optional
+ from typing import Dict, List

  import torch
  from datasets import load_dataset
@@ -14,10 +15,32 @@ from transformers import (
  from peft import LoraConfig, TaskType, get_peft_model


+ def _format_as_chat(tokenizer, ex: Dict) -> str:
+     system = (ex.get("system") or "").strip()
+     user = (ex.get("user") or "").strip()
+     assistant = (ex.get("assistant") or "").strip()
+
+     # Preferred: model-native chat template (Llama/Qwen/Mistral Instruct, etc.)
+     if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+         messages: List[Dict[str, str]] = []
+         if system:
+             messages.append({"role": "system", "content": system})
+         messages.append({"role": "user", "content": user})
+         messages.append({"role": "assistant", "content": assistant})
+         return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
+
+     # Fallback: simple transcript
+     parts = []
+     if system:
+         parts.append(f"### System:\n{system}")
+     parts.append(f"### User:\n{user}")
+     parts.append(f"### Assistant:\n{assistant}")
+     return "\n\n".join(parts)
+
+
  def finetune_lora(
      base_model: str,
      dataset_id: str,
-     text_column: str,
      output_dir: str,
      max_train_samples: int = 2000,
      max_steps: int = 100,
@@ -28,8 +51,11 @@ def finetune_lora(
      lora_dropout: float = 0.05,
  ) -> str:
      ds = load_dataset(dataset_id, split="train")
-     if text_column not in ds.column_names:
-         return f"ERROR: column '{text_column}' not found. Available: {ds.column_names}"
+
+     needed = {"system", "user", "assistant"}
+     missing = needed.difference(set(ds.column_names))
+     if missing:
+         return f"ERROR: dataset missing columns {sorted(missing)}. Found: {ds.column_names}"

      if max_train_samples and max_train_samples > 0:
          ds = ds.select(range(min(len(ds), int(max_train_samples))))
@@ -39,15 +65,23 @@ def finetune_lora(
      tokenizer.pad_token = tokenizer.eos_token

      def tok(batch):
-         return tokenizer(batch[text_column], truncation=True, max_length=256)
+         texts = [_format_as_chat(tokenizer, ex) for ex in batch]
+         return tokenizer(texts, truncation=True, max_length=1024)
+
+     # map with batched=True expects a dict-of-lists; easiest is to build list of dicts per batch
+     def batched_map(batch):
+         # Convert dict-of-lists to list-of-dicts
+         exs = [dict(zip(batch.keys(), vals)) for vals in zip(*batch.values())]
+         return tok(exs)

-     tokenized = ds.map(tok, batched=True, remove_columns=ds.column_names)
+     tokenized = ds.map(batched_map, batched=True, remove_columns=ds.column_names)

      model = AutoModelForCausalLM.from_pretrained(base_model)
      model.config.pad_token_id = tokenizer.pad_token_id

-     # LoRA target modules here are GPT-2-ish defaults.
-     # If you swap to a non-GPT2 architecture, you may need to change target_modules.
+     # NOTE: target_modules depends on model architecture.
+     # GPT-2 uses c_attn/c_proj; Llama uses q_proj/k_proj/v_proj/o_proj; Qwen varies.
+     # Keep GPT-2 defaults here and change if you swap base_model.
      lora_cfg = LoraConfig(
          task_type=TaskType.CAUSAL_LM,
          r=int(lora_r),
@@ -72,13 +106,7 @@ def finetune_lora(
          fp16=fp16,
      )

-     trainer = Trainer(
-         model=model,
-         args=args,
-         train_dataset=tokenized,
-         data_collator=collator,
-     )
-
+     trainer = Trainer(model=model, args=args, train_dataset=tokenized, data_collator=collator)
      trainer.train()

      adapter_dir = os.path.join(output_dir, "adapter")
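For reference, a minimal sketch of how the updated entry point might be called after this change. The base model, dataset id, and output path below are illustrative placeholders (not values from this commit); it assumes a dataset exposing the system/user/assistant columns the new validation checks for, and a GPT-2-family base model so the default LoRA target_modules still apply.

# Hypothetical usage sketch; identifiers below are placeholders, not part of the commit.
from train import finetune_lora  # assumes src/ is on the import path

result = finetune_lora(
    base_model="gpt2",                    # GPT-2-family keeps the default target_modules valid
    dataset_id="your-org/chat-sft-demo",  # placeholder dataset with system/user/assistant columns
    output_dir="outputs/gpt2-chat-lora",
    max_train_samples=2000,
    max_steps=100,
)
print(result)  # status string; missing columns surface as the "ERROR: ..." early return

Swapping in a non-GPT-2 base model would also mean updating target_modules in the LoraConfig, as the new comment in the diff notes.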