Musombi committed on
Commit 264e4da · verified · 1 Parent(s): 90a32b4

Update training/train_language.py

Files changed (1)
  1. training/train_language.py +16 -42
training/train_language.py CHANGED
@@ -12,22 +12,27 @@ from language.intent import IntentClassifier
from datasets import load_dataset

# ================================
- # LOAD DATA FROM HUGGING FACE
+ # CONFIG
# ================================
- print("[INFO] Loading dataset from Hugging Face...")
-
- load_dataset("clinc_oos", "plus")
+ ARTIFACTS_DIR = "artifacts"
+ BATCH_SIZE = 16
+ EPOCHS = 10
+ LEARNING_RATE = 3e-4
+ MAX_SEQ_LEN = 64

+ os.makedirs(ARTIFACTS_DIR, exist_ok=True)

+ # ================================
+ # LOAD DATA FROM HUGGING FACE
+ # ================================
+ print("[INFO] Loading dataset from Hugging Face...")
hf_dataset = load_dataset("clinc_oos", "plus")

texts = hf_dataset["train"]["text"]
labels = hf_dataset["train"]["intent"]
-
intent_labels = sorted(list(set(labels)))

- print(f"[INFO] Loaded {len(texts)} samples from Hugging Face")
-
+ print(f"[INFO] Loaded {len(texts)} samples with {len(intent_labels)} intents")

# ================================
# DATASET
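For reference, the new data-loading block relies on the Hugging Face datasets library returning integer intent ids for CLINC150 ("clinc_oos", "plus"). A minimal standalone sketch of just that step; the final int2str lookup is illustrative and not in the script:

from datasets import load_dataset

# Load CLINC150, "plus" configuration, as in the updated script.
hf_dataset = load_dataset("clinc_oos", "plus")

texts = hf_dataset["train"]["text"]
labels = hf_dataset["train"]["intent"]        # integer class ids
intent_labels = sorted(list(set(labels)))     # unique ids, as the script builds them

print(f"[INFO] Loaded {len(texts)} samples with {len(intent_labels)} intents")

# The intent column is a ClassLabel, so an id can be mapped back to its name:
print(hf_dataset["train"].features["intent"].int2str(labels[0]))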
@@ -43,9 +48,7 @@ class LanguageDataset(Dataset):
        return len(self.texts)

    def __getitem__(self, idx):
-         token_ids = self.tokenizer.encode(self.texts[idx])
-         # Truncate to max_seq_len
-         token_ids = token_ids[:self.max_seq_len]
+         token_ids = self.tokenizer.encode(self.texts[idx])[:self.max_seq_len]
        token_ids = torch.tensor(token_ids, dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return token_ids, label
@@ -59,34 +62,11 @@ def collate_fn(batch, tokenizer):
    padded = []
    for t in token_ids:
        pad_len = max_len - len(t)
-         padded.append(
-             torch.cat([t, torch.full((pad_len,), tokenizer.vocab[tokenizer.PAD_TOKEN], dtype=torch.long)])
-         )
+         padded.append(torch.cat([t, torch.full((pad_len,), tokenizer.vocab[tokenizer.PAD_TOKEN], dtype=torch.long)]))
    return torch.stack(padded), torch.tensor(labels)

# ================================
- # LOAD DATA
- # ================================
- def load_data(path):
-     texts, labels = [], []
-     intent_labels = set()
-     if not os.path.exists(path):
-         raise FileNotFoundError(f"Dataset file not found: {path}")
-
-     with open(path, "r", encoding="utf-8") as f:
-         for line in f:
-             line = line.strip()
-             if line:
-                 text, intent = line.split("\t")
-                 texts.append(text)
-                 labels.append(intent)
-                 intent_labels.add(intent)
-     return texts, labels, sorted(list(intent_labels))
-
- texts, labels, intent_labels = load_data("musombi/intent_datasets")
-
- # ================================
- # TOKENIZER
+ # TOKENIZER AND DATALOADER
# ================================
tokenizer = SimpleTokenizer()
tokenizer.build_vocab(texts)
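For reference, the collapsed padded.append(...) line right-pads every sequence in a batch to the batch maximum with the PAD token id before stacking. A toy, self-contained sketch of that padding step, assuming a pad id of 0 purely for illustration (the script takes it from tokenizer.vocab[tokenizer.PAD_TOKEN]):

import torch

PAD_ID = 0  # assumed pad id, for this sketch only

token_ids = [torch.tensor([5, 8, 2]), torch.tensor([7, 1])]
labels = [3, 7]
max_len = max(len(t) for t in token_ids)

padded = []
for t in token_ids:
    pad_len = max_len - len(t)
    padded.append(torch.cat([t, torch.full((pad_len,), PAD_ID, dtype=torch.long)]))

batch = torch.stack(padded)   # shape (2, 3); the shorter sequence is right-padded with PAD_ID
print(batch)
print(torch.tensor(labels))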
@@ -95,8 +75,6 @@ tokenizer.freeze_vocab()
dataset = LanguageDataset(texts, labels, tokenizer, intent_labels, max_seq_len=MAX_SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda batch: collate_fn(batch, tokenizer))

- print(f"[INFO] Loaded {len(dataset)} samples with {len(intent_labels)} intents")
-
# ================================
# MODEL
# ================================
@@ -107,10 +85,7 @@ classifier = IntentClassifier(input_dim=encoder.projection.out_features, intent_
embedder, encoder, classifier = embedder.to(DEVICE), encoder.to(DEVICE), classifier.to(DEVICE)

criterion = nn.CrossEntropyLoss()
- optimizer = optim.Adam(
-     list(embedder.parameters()) + list(encoder.parameters()) + list(classifier.parameters()),
-     lr=LEARNING_RATE
- )
+ optimizer = optim.Adam(list(embedder.parameters()) + list(encoder.parameters()) + list(classifier.parameters()), lr=LEARNING_RATE)

# ================================
# TRAINING LOOP
@@ -121,7 +96,6 @@ def train():
    total_loss = 0
    for token_ids, labels_batch in loader:
        token_ids, labels_batch = token_ids.to(DEVICE), labels_batch.to(DEVICE)
-
        embeddings = embedder(token_ids)
        attention_mask = (token_ids != tokenizer.vocab[tokenizer.PAD_TOKEN]).long()
        sentence_vec = encoder(embeddings, attention_mask=attention_mask)
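For reference, the training loop builds the attention mask directly from the padded batch: positions holding the PAD token id get 0, real tokens get 1. A small sketch of that mask construction, again assuming a pad id of 0; the masked mean pooling at the end is only one common use of such a mask and is an assumption here, since the encoder's internals are not shown in this diff:

import torch

PAD_ID = 0  # assumed pad id, for this sketch only

token_ids = torch.tensor([[5, 8, 2],
                          [7, 1, PAD_ID]])        # a padded batch, as produced by collate_fn
attention_mask = (token_ids != PAD_ID).long()     # 1 for real tokens, 0 for padding

embeddings = torch.randn(2, 3, 4)                 # stand-in for embedder(token_ids)

# Masked mean pooling: average embeddings over real-token positions only.
mask = attention_mask.unsqueeze(-1).float()       # (batch, seq, 1)
sentence_vec = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

print(attention_mask)
print(sentence_vec.shape)                         # torch.Size([2, 4])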
 