Devansh0711 commited on
Commit
ce312c4
·
verified ·
1 Parent(s): 69bd982

Upload train_minigpt.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_minigpt.py +168 -0
train_minigpt.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Training entry point for the MoE MiniGPT language model (TensorFlow)."""

# Standard library
import json
import logging
import math
import os

# Third party
import numpy as np
import tensorflow as tf
from tokenizers import ByteLevelBPETokenizer
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast

# from tensorflow.keras import mixed_precision

# Project-local model definition.
from minigpt_transformer import MoEMiniGPT, MoEConfig

# Configure root logging once; every message carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# mixed_precision.set_global_policy('mixed_float16')
if __name__ == "__main__":
    try:
        # Tokenizer trained separately (10k-vocab byte-level BPE). The special
        # tokens here must match the ones the tokenizer was trained with.
        tokenizer = PreTrainedTokenizerFast(
            tokenizer_file="my-10k-bpe-tokenizer/tokenizer.json",
            unk_token="<unk>",
            pad_token="<pad>",
            bos_token="<s>",
            eos_token="</s>",
            mask_token="<mask>",
        )

        # Model hyperparameters. `use_moe_layers` lists the transformer layers
        # that use a mixture-of-experts FFN instead of a dense one
        # (semantics per MoEConfig — TODO confirm against minigpt_transformer).
        config = MoEConfig(
            vocab_size=10000,
            max_seq_len=256,
            seq_len=256,
            embed_dim=512,
            num_heads=8,
            num_layers=8,
            ffn_dim=2048,
            dropout=0.1,
            layer_norm_epsilon=1e-5,
            use_rotary_embeddings=True,
            learning_rate=2e-4,
            batch_size=32,
            num_experts=4,
            top_k_experts=1,
            use_moe_layers=[2, 4, 6]
        )

        logger.info("Initializing MoEMiniGPT model...")
        model = MoEMiniGPT(config)

        # One forward pass so all weights are built before counting parameters.
        dummy_input = tf.ones((1, config.seq_len), dtype=tf.int32)
        _ = model(dummy_input)

        total_params = np.sum([np.prod(v.shape) for v in model.trainable_variables])
        logger.info(f"Total model parameters: {total_params:,}")

        # Load corpus: one training example per non-empty line.
        corpus_path = "corpus.txt"
        with open(corpus_path, "r", encoding="utf-8") as f:
            lines = [line.strip() for line in f if line.strip()]

        def encode_line(line):
            """Tokenize one line, truncated/padded to exactly config.seq_len ids."""
            tokens = tokenizer.encode(
                line,
                max_length=config.seq_len,
                truncation=True,
                padding='max_length'
            )
            return {"input_ids": np.array(tokens, dtype=np.int32)}

        encoded = [encode_line(line) for line in lines]

        train_dataset = tf.data.Dataset.from_generator(
            lambda: (ex for ex in encoded),
            output_signature={"input_ids": tf.TensorSpec(shape=(config.seq_len,), dtype=tf.int32)}
        ).shuffle(2048).batch(config.batch_size)

        logger.info(f"Training dataset created with {len(encoded)} examples.")
        # Corpus statistic only: re-encode without truncation so long lines
        # are counted in full (this double-tokenizes by design).
        total_tokens = sum(len(tokenizer.encode(line)) for line in lines)
        logger.info(f"Total number of tokens in corpus: {total_tokens}")

        train_loss_metric = tf.keras.metrics.Mean(name='train_loss')
        train_accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

        optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate)

        # Fix: resolve the pad id once, outside the traced function, and guard
        # against a tokenizer whose pad_token_id is None (tf.not_equal would
        # fail on None at trace time).
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = 0

        @tf.function
        def train_step(batch):
            """Run one optimizer step on a batch; return the masked total loss."""
            input_ids = batch['input_ids']
            # Causal-LM shift: predict token t+1 from tokens <= t.
            targets = input_ids[:, 1:]
            inputs = input_ids[:, :-1]
            with tf.GradientTape() as tape:
                logits, aux_losses = model(inputs, training=True)
                loss = tf.keras.losses.sparse_categorical_crossentropy(
                    targets, logits, from_logits=True
                )
                # Mask pad positions out of both the loss and the accuracy.
                mask = tf.cast(tf.not_equal(targets, pad_token_id), tf.float32)
                mask_sum = tf.reduce_sum(mask)
                # Epsilon avoids 0/0 on an all-padding batch.
                loss = tf.reduce_sum(loss * mask) / (mask_sum + 1e-8)
                if aux_losses:
                    # Auxiliary (MoE load-balancing) losses reported by the model.
                    loss += tf.add_n(list(aux_losses.values()))
            grads = tape.gradient(loss, model.trainable_variables)
            grads, _ = tf.clip_by_global_norm(grads, 1.0)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            train_loss_metric.update_state(loss)
            train_accuracy_metric.update_state(targets, logits, sample_weight=mask)
            return loss

        logger.info("Starting training...")
        epochs = 1
        steps_per_epoch = math.ceil(len(encoded) / config.batch_size)
        logger.info(f"Epochs: {epochs}, Steps per epoch: {steps_per_epoch}")

        global_step = 0
        acc_val = 0.0  # fix: defined even if the dataset yields no batches
        for epoch in range(epochs):
            train_loss_metric.reset_state()
            train_accuracy_metric.reset_state()
            epoch_losses = []

            logger.info(f"Epoch {epoch+1}/{epochs} started.")
            progbar = tqdm(train_dataset, total=steps_per_epoch, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
            for step, batch in enumerate(progbar, 1):
                global_step += 1
                loss = train_step(batch)
                epoch_losses.append(loss.numpy())

                loss_val = train_loss_metric.result().numpy()
                acc_val = train_accuracy_metric.result().numpy()

                progbar.set_postfix({
                    "step": f"{step}/{steps_per_epoch}",
                    "loss": f"{loss_val:.4f}",
                    "acc": f"{acc_val:.4f}"
                })

            # Fix: empty dataset no longer crashes, and a huge early loss can
            # no longer raise OverflowError in math.exp.
            avg_loss = float(np.mean(epoch_losses)) if epoch_losses else float("nan")
            perplexity = math.exp(min(avg_loss, 700.0)) if math.isfinite(avg_loss) else float("nan")
            logger.info(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f} | Accuracy: {acc_val:.4f} | Perplexity: {perplexity:.2f}")

        # Persist weights plus the config needed to rebuild the model.
        save_dir = "trained_models"
        os.makedirs(save_dir, exist_ok=True)
        weights_path = os.path.join(save_dir, "moe_minigpt.weights.h5")
        model.save_weights(weights_path)
        logger.info(f"Model weights saved to {weights_path}")

        config_path = os.path.join(save_dir, "moe_config.json")

        # Fix: the previous dump stringified lists and None (e.g.
        # use_moe_layers -> "[2, 4, 6]"), so the saved config could not be
        # reloaded without parsing strings. Keep JSON-native values as-is and
        # only stringify values json cannot represent.
        def _jsonable(v):
            """Return v unchanged when JSON-serializable, else str(v)."""
            try:
                json.dumps(v)
                return v
            except (TypeError, ValueError):
                return str(v)

        with open(config_path, 'w') as f:
            config_dict = {k: _jsonable(v) for k, v in vars(config).items()}
            json.dump(config_dict, f, indent=2)
        logger.info(f"Configuration saved to {config_path}")

        # Optional interactive sanity check when the model supports generation.
        if hasattr(model, "generate_text"):
            print("\n--- Chat with your model! Type 'quit' to exit. ---")
            while True:
                user_input = input("You: ")
                if user_input.strip().lower() in ["quit", "exit"]:
                    print("Exiting chat.")
                    break
                response = model.generate_text(user_input, max_length=50)
                print("Model:", response)

    except Exception as e:
        # Fix: logger.exception records the traceback; plain error() dropped it.
        # Re-raise so the process exits non-zero.
        logger.exception(f"Error in main execution: {e}")
        raise