vedaco committed on
Commit
3304074
·
verified ·
1 Parent(s): a7b3073

Create akasha/utils.py

Browse files
Files changed (1) hide show
  1. akasha/utils.py +127 -0
akasha/utils.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for AKASHA.
3
+ """
4
+
5
+ import tensorflow as tf
6
+ import numpy as np
7
+ import json
8
+ import os
9
+
10
+
11
def load_config(config_path="config.json"):
    """Load a JSON configuration file and return the parsed content.

    Args:
        config_path: Path of the JSON file to read. Defaults to
            ``"config.json"`` (resolved against the current working
            directory).

    Returns:
        The value produced by ``json.load`` for the file's top-level JSON
        value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8; pin the encoding so parsing does not depend on the
    # platform's locale default (which is not UTF-8 on every system).
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
15
+
16
+
17
def create_default_config():
    """Build and return the default AKASHA configuration.

    The result is a freshly constructed nested dict on every call, with four
    top-level sections: ``"model"`` (tokenizer, transformer and generation
    settings), ``"training"``, ``"data"`` and ``"huggingface"``.
    """
    # Image tokenizer settings.
    tokenizer_cfg = {
        "image_size": 256,
        "patch_size": 8,
        "num_tokens": 1024,
        "codebook_dim": 256,
        "encoder_hidden_dims": [64, 128, 256, 512],
        "decoder_hidden_dims": [512, 256, 128, 64],
        "commitment_cost": 0.25,
        "num_residual_blocks": 2,
    }

    # Transformer settings.
    transformer_cfg = {
        "num_layers": 24,
        "d_model": 1024,
        "num_heads": 16,
        "d_ff": 4096,
        "dropout_rate": 0.1,
        "max_sequence_length": 1024,
        "vocab_size": 1024,
        "use_rotary_embeddings": True,
    }

    # Sampling settings.
    generation_cfg = {
        "temperature": 0.9,
        "top_k": 100,
        "top_p": 0.95,
    }

    model_cfg = {
        "name": "AKASHA",
        "version": "1.0",
        "tokenizer": tokenizer_cfg,
        "transformer": transformer_cfg,
        "generation": generation_cfg,
    }

    # Global training settings plus per-stage overrides.
    training_cfg = {
        "batch_size": 32,
        "learning_rate": 3e-4,
        "warmup_steps": 4000,
        "total_steps": 500000,
        "weight_decay": 0.01,
        "gradient_clip_norm": 1.0,
        "mixed_precision": True,
        "stage1": {"epochs": 100, "learning_rate": 1e-4, "batch_size": 64},
        "stage2": {"epochs": 200, "learning_rate": 3e-4, "batch_size": 32},
    }

    data_cfg = {"dataset": "imagenet", "image_size": 256, "augmentation": True}
    hub_cfg = {"repo_id": "vedaco/AKASHA", "space_sdk": "gradio"}

    return {
        "model": model_cfg,
        "training": training_cfg,
        "data": data_cfg,
        "huggingface": hub_cfg,
    }
62
+
63
+
64
class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with linear warmup followed by cosine decay.

    For ``step < warmup_steps`` the rate is ``base_lr * step / warmup_steps``.
    From ``warmup_steps`` onwards the rate follows half a cosine wave from
    ``base_lr`` down to ``min_lr``; decay progress is clipped to ``[0, 1]``,
    so the rate stays at ``min_lr`` once ``total_steps`` has been reached.
    """

    def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-6):
        """Store the schedule parameters.

        Args:
            base_lr: Peak learning rate, reached at the end of warmup.
            warmup_steps: Number of steps of linear warmup.
            total_steps: Step at which the cosine decay reaches ``min_lr``.
            min_lr: Floor of the cosine decay. Defaults to ``1e-6``.
        """
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        step = tf.cast(step, tf.float32)

        # Both branches are computed unconditionally; tf.where at the end
        # selects between them per element.
        warmup_lr = self.base_lr * (step / self.warmup_steps)

        # When total_steps == warmup_steps the decay span is zero and the
        # division below would be 0/0 at step == warmup_steps, which can
        # produce a NaN learning rate. Substitute a span of 1 in that case
        # only; every non-zero span is used exactly as given.
        decay_steps = tf.cast(self.total_steps - self.warmup_steps, tf.float32)
        decay_steps = tf.where(
            tf.equal(decay_steps, 0.0), tf.ones_like(decay_steps), decay_steps
        )

        progress = (step - self.warmup_steps) / decay_steps
        progress = tf.clip_by_value(progress, 0.0, 1.0)
        cosine_lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (
            1.0 + tf.cos(np.pi * progress)
        )
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        """Return the constructor arguments as a dict, keyed by parameter name."""
        return {
            "base_lr": self.base_lr,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "min_lr": self.min_lr,
        }
89
+
90
+
91
def save_images_grid(images, filepath, grid_size=None):
    """Tile a batch of images into a single square grid and save it to disk.

    Args:
        images: Batch of images as a ``tf.Tensor`` or an array exposing
            ``.shape``, indexed as ``images[i]``. Each image is multiplied by
            255, clipped to ``[0, 255]`` and cast to ``uint8`` before being
            pasted, so values are presumably expected in ``[0, 1]`` (confirm
            against callers). The canvas has 3 channels.
        filepath: Destination path handed to ``PIL.Image.save``.
        grid_size: Number of tiles per row and per column. When ``None`` it
            is ``ceil(sqrt(n))`` for a batch of ``n`` images.

    Returns:
        ``filepath``, unchanged.
    """
    from PIL import Image as PILImage

    batch = images.numpy() if isinstance(images, tf.Tensor) else images
    count = batch.shape[0]
    side = int(np.ceil(np.sqrt(count))) if grid_size is None else grid_size
    tile_h, tile_w = batch.shape[1], batch.shape[2]

    # Black canvas; cells without a corresponding image stay zero.
    canvas = np.zeros((side * tile_h, side * tile_w, 3), dtype=np.uint8)

    # Images beyond the grid capacity (side * side) are ignored.
    for idx in range(min(count, side * side)):
        r, c = divmod(idx, side)
        tile = (batch[idx] * 255).clip(0, 255).astype(np.uint8)
        rows = slice(r * tile_h, (r + 1) * tile_h)
        cols = slice(c * tile_w, (c + 1) * tile_w)
        canvas[rows, cols] = tile

    PILImage.fromarray(canvas).save(filepath)
    return filepath
107
+
108
+
109
def count_parameters(model):
    """Return the total number of elements across a model's trainable variables.

    Args:
        model: Object exposing ``trainable_variables``, an iterable of
            variables that each have a ``shape``.

    Returns:
        The element count as a plain Python ``int`` (``0`` when there are no
        trainable variables).
    """
    # np.prod of an empty shape is the float 1.0 by default, which would turn
    # the whole sum into a float as soon as one scalar variable is present.
    # Forcing an integer dtype keeps the count integral; int() then converts
    # the NumPy scalar into a regular Python int.
    return int(
        sum(np.prod(v.shape, dtype=np.int64) for v in model.trainable_variables)
    )
111
+
112
+
113
def get_model_summary(config):
    """Print a short, human-readable summary of the model configuration.

    Reads ``config["model"]["tokenizer"]`` and
    ``config["model"]["transformer"]`` and writes the image size, patch size,
    derived grid size (``image_size // patch_size``), sequence length (grid
    size squared), codebook size and transformer dimensions to stdout.

    Args:
        config: Nested configuration dict containing the keys listed above.

    Returns:
        None. The summary is printed, not returned.
    """
    tokenizer_cfg = config["model"]["tokenizer"]
    transformer_cfg = config["model"]["transformer"]
    side = tokenizer_cfg["image_size"] // tokenizer_cfg["patch_size"]
    n_tokens = side * side
    rule = "=" * 60

    def summary_lines():
        # Lines are produced lazily so each one is formatted only right
        # before it is printed.
        yield rule
        yield " AKASHA Model Configuration"
        yield rule
        yield f" Image Size: {tokenizer_cfg['image_size']}x{tokenizer_cfg['image_size']}"
        yield f" Patch Size: {tokenizer_cfg['patch_size']}x{tokenizer_cfg['patch_size']}"
        yield f" Grid Size: {side}x{side}"
        yield f" Sequence Length: {n_tokens} tokens"
        yield f" Codebook Size: {tokenizer_cfg['num_tokens']}"
        yield (
            f" Transformer: {transformer_cfg['num_layers']}L"
            f" / {transformer_cfg['d_model']}D"
            f" / {transformer_cfg['num_heads']}H"
        )
        yield rule

    for line in summary_lines():
        print(line)