WildnerveAI committed on
Commit
a829f5c
·
verified ·
1 Parent(s): 958a37b

Delete train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +0 -172
train_model.py DELETED
@@ -1,172 +0,0 @@
1
import glob
import json
import logging
import os
import time
from typing import Optional, Dict, List, Any

import torch
from accelerate import Accelerator
from datasets import load_dataset, concatenate_datasets, Features, Value
from torch import nn, optim
from torch.utils.data import DataLoader
11
-
12
- # Import your core model; choose one implementation for training.
13
- from model_Custm import Wildnerve_tlm01
14
-
15
- logger = logging.getLogger(__name__)
16
- logging.basicConfig(level=logging.INFO)
17
-
18
# Helpers that turn nested JSON into a single flat training string with
# hierarchical markers (dicts -> "key:{...}", lists -> "[a, b]").
def flatten_json(data):
    """Recursively flatten *data* into one string.

    Dicts become space-joined ``key:{flattened-value}`` pairs, lists become
    ``[item, item, ...]``, and any other value is rendered with ``str()``.
    """
    if isinstance(data, dict):
        parts = [f"{key}:{{{flatten_json(value)}}}" for key, value in data.items()]
        return " ".join(parts)
    if isinstance(data, list):
        return "[" + ", ".join(flatten_json(item) for item in data) + "]"
    return str(data)


def convert_record(record):
    """Map a raw dataset record to ``{"input": <flattened text>}``.

    If the record's ``text`` field parses as JSON it is flattened with
    ``flatten_json``; otherwise the raw value is passed through unchanged.
    Previously this caught a bare ``Exception`` around the whole body; the
    handler is now narrowed to the two failure modes ``json.loads`` can
    actually raise here (invalid JSON, or a non-string/bytes value).
    """
    raw = record.get("text", "")
    try:
        data = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        # Not JSON (or not decodable text at all) -- keep the raw value.
        return {"input": raw}
    return {"input": flatten_json(data)}
42
-
43
# Tokenizer used to convert raw text into input_id tensors for the model.
# NOTE(review): this runs at import time, so importing this module hits the
# Hugging Face hub (or its local cache) as a side effect -- confirm that is
# acceptable for every importer of this file.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
46
-
47
# Loads from a Hugging Face repo with a local-JSON fallback.
def get_dataset(split="train", use_hf_data=True, dataset_repo="EvolphTech/data",
                data_dir=r"c:\Users\User\OneDrive\Documents\tlm\Wildnerve-tlm_HF"):
    """Build a tokenized torch Dataset for *split*.

    Tries the Hugging Face repo *dataset_repo* first (when *use_hf_data* is
    True) and falls back to local JSON files under *data_dir* on any failure.
    Returns a torch ``Dataset`` yielding fixed-length (128) input_id tensors.

    Args:
        split: Dataset split to load ("train" or "validation").
        use_hf_data: Load from the Hugging Face hub instead of local files.
        dataset_repo: Hub dataset repository to pull from.
        data_dir: Directory holding ``train.json`` / ``validation.json`` for
            the local fallback. New parameter; its default is the previously
            hard-coded path, so existing callers are unaffected.
    """
    if use_hf_data:
        try:
            logger.info(f"Loading dataset from Hugging Face: {dataset_repo}")
            dataset = load_dataset(dataset_repo, split=split)

            # Prefer an explicit 'text' column; otherwise fall back to the
            # first string-typed column we can find.
            if 'text' in dataset.column_names:
                dataset = dataset.map(lambda x: {"input": x["text"]})
            else:
                logger.warning(f"No 'text' column found in {dataset_repo}. Using first text column found.")
                text_columns = [col for col in dataset.column_names
                                if dataset.features[col].dtype == 'string']
                if text_columns:
                    dataset = dataset.map(lambda x: {"input": x[text_columns[0]]})
                else:
                    # Caught by the except below, which triggers the fallback.
                    raise ValueError(f"No text columns found in {dataset_repo}")

            logger.info(f"Successfully loaded {len(dataset)} samples from Hugging Face")
        except Exception as e:
            # Best-effort by design: any hub/schema failure falls back to
            # the local copy rather than aborting training.
            logger.error(f"Failed to load dataset from Hugging Face: {e}")
            logger.info("Falling back to local dataset")
            return get_dataset(split=split, use_hf_data=False, data_dir=data_dir)
    else:
        # Local JSON fallback with a fixed single-column schema.
        data_files = {
            "train": os.path.join(data_dir, "train.json"),
            "validation": os.path.join(data_dir, "validation.json")
        }
        features = Features({"text": Value("string")})
        dataset = load_dataset("json", data_files=data_files, features=features,
                               split=split, download_mode="force_redownload")
        dataset = dataset.map(lambda x: {"input": x["text"]})

    class CustomDataset(torch.utils.data.Dataset):
        """Tokenizes one text sample into a 128-token input_id tensor."""

        def __init__(self, data):
            self.data = data["input"]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            tokens = tokenizer(self.data[idx], truncation=True, padding="max_length",
                               max_length=128, return_tensors="pt")
            # squeeze(0) drops the batch dim added by return_tensors="pt".
            return tokens["input_ids"].squeeze(0)

    return CustomDataset(dataset)
92
-
93
def train(use_hf_data=True, dataset_repo="EvolphTech/data", epochs=50):
    """Train Wildnerve_tlm01 on the flattened-text dataset.

    Args:
        use_hf_data: Pull data from the Hugging Face hub (with local fallback).
        dataset_repo: Hub dataset repository name.
        epochs: Number of training epochs. New backward-compatible parameter;
            defaults to the previously hard-coded 50.

    Side effects:
        Writes model weights to ``<results_dir>/model_weights.pt`` and a
        byte-identical copy as ``model_weights.bin``.
    """
    accelerator = Accelerator()
    train_dataset = get_dataset("train", use_hf_data=use_hf_data, dataset_repo=dataset_repo)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Model hyper-parameters are fixed here; adjust the constructor as needed.
    model = Wildnerve_tlm01(
        vocab_size=30522,
        specialization="general",
        dataset_path="",
        model_name="bert-base-uncased",
        embedding_dim=256,
        num_heads=4,
        hidden_dim=256,
        num_layers=2,
        output_size=256,
        dropout=0.1,
        max_seq_length=128,
        pooling_mode="mean",
        use_pretrained_encoder=True
    )
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # Next-token objective: model output logits are assumed to be
    # [batch, vocab_size] (a single predicted position) -- TODO confirm
    # against Wildnerve_tlm01's forward().
    criterion = nn.CrossEntropyLoss()

    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    for epoch in range(epochs):
        total_loss = 0.0
        for batch in train_loader:
            x = batch[..., :-1]  # inputs: drop the last token
            y = batch[..., 1:]   # labels: drop the first token

            optimizer.zero_grad()
            output = model(x)  # expected shape: [batch_size, vocab_size]

            # Shape trace for debugging. NOTE(review): this logs once per
            # batch at INFO level, which is noisy on large datasets.
            logger.info(f"Epoch {epoch+1}, Output shape: {output.shape}, Target shape: {y.shape}")

            # The model emits logits for a single position, so compare
            # against the first shifted label token only.
            target = y[:, 0].long()
            loss = criterion(output, target)

            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        logger.info(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        time.sleep(1)  # simulate longer training

    # Persist weights: save as .pt, then duplicate the same state_dict
    # under a .bin extension.
    results_dir = r"c:\Users\User\OneDrive\Documents\tlm\results"
    os.makedirs(results_dir, exist_ok=True)

    pt_save_path = os.path.join(results_dir, "model_weights.pt")
    torch.save(model.state_dict(), pt_save_path)
    logger.info(f"Model weights saved to {pt_save_path}")

    bin_save_path = os.path.join(results_dir, "model_weights.bin")
    state_dict = torch.load(pt_save_path, weights_only=True)
    torch.save(state_dict, bin_save_path)
    logger.info(f"Model weights also saved as binary to {bin_save_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Train the model")
    parser.add_argument("--use_hf_data", action="store_true", help="Use data from Hugging Face repo")
    parser.add_argument("--dataset_repo", type=str, default="EvolphTech/data", help="Hugging Face dataset repository")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")

    args = parser.parse_args()

    # Bug fix: --epochs was parsed but previously ignored; wire it through.
    train(use_hf_data=args.use_hf_data, dataset_repo=args.dataset_repo, epochs=args.epochs)