# AKASHA / akasha/utils.py
# Author: vedaco
# Commit 3304074 (verified): "Create akasha/utils.py"
"""
Utility functions for AKASHA.
"""
import tensorflow as tf
import numpy as np
import json
import os
def load_config(config_path="config.json"):
    """Load and return a JSON configuration file.

    Args:
        config_path: path to the JSON file. The default ``"config.json"``
            is resolved relative to the current working directory.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259). Say so explicitly instead of relying on
    # the platform's locale encoding (e.g. cp1252 on Windows).
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def create_default_config():
    """Return the default AKASHA configuration as a freshly built nested dict.

    Top-level sections, in order:
        model: ``name``/``version`` plus ``tokenizer``, ``transformer`` and
            ``generation`` hyperparameter groups.
        training: optimizer/schedule settings and ``stage1``/``stage2``
            overrides.
        data: dataset name, image size and augmentation flag.
        huggingface: Hub repository id and Space SDK.

    Every call constructs new dicts and lists, so callers may mutate the
    result without affecting later calls.
    """
    tokenizer = dict(
        image_size=256,
        patch_size=8,
        num_tokens=1024,
        codebook_dim=256,
        encoder_hidden_dims=[64, 128, 256, 512],
        decoder_hidden_dims=[512, 256, 128, 64],
        commitment_cost=0.25,
        num_residual_blocks=2,
    )
    transformer = dict(
        num_layers=24,
        d_model=1024,
        num_heads=16,
        d_ff=4096,
        dropout_rate=0.1,
        max_sequence_length=1024,
        vocab_size=1024,
        use_rotary_embeddings=True,
    )
    generation = dict(temperature=0.9, top_k=100, top_p=0.95)

    training = dict(
        batch_size=32,
        learning_rate=3e-4,
        warmup_steps=4000,
        total_steps=500000,
        weight_decay=0.01,
        gradient_clip_norm=1.0,
        mixed_precision=True,
        stage1=dict(epochs=100, learning_rate=1e-4, batch_size=64),
        stage2=dict(epochs=200, learning_rate=3e-4, batch_size=32),
    )

    return dict(
        model=dict(
            name="AKASHA",
            version="1.0",
            tokenizer=tokenizer,
            transformer=transformer,
            generation=generation,
        ),
        training=training,
        data=dict(dataset="imagenet", image_size=256, augmentation=True),
        huggingface=dict(repo_id="vedaco/AKASHA", space_sdk="gradio"),
    )
class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule: linear warmup, then cosine decay to ``min_lr``.

    For ``step < warmup_steps`` the rate is ``base_lr * step / warmup_steps``
    (rising linearly from 0). From ``warmup_steps`` on, it follows a half
    cosine from ``base_lr`` down to ``min_lr``. It reaches ``min_lr`` at
    ``total_steps`` and stays there afterwards.

    Args:
        base_lr: peak learning rate, reached at the end of warmup.
        warmup_steps: number of linear-warmup steps.
        total_steps: step at which the cosine decay reaches ``min_lr``.
        min_lr: floor of the cosine decay.
    """

    def __init__(self, base_lr, warmup_steps, total_steps, min_lr=1e-6):
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr

    def __call__(self, step):
        """Return the learning rate for ``step`` as a float32 tensor."""
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        # Clamp both denominators to >= 1. Without this, warmup_steps == 0 or
        # total_steps <= warmup_steps can produce 0/0 = NaN; at
        # step == warmup_steps the NaN would pass through clip_by_value and
        # become the learning rate. Normal configs (warmup >= 1,
        # total > warmup) are unaffected.
        warmup_denom = tf.maximum(warmup_steps, 1.0)
        decay_denom = tf.maximum(
            tf.cast(self.total_steps, tf.float32) - warmup_steps, 1.0
        )

        warmup_lr = self.base_lr * (step / warmup_denom)

        # Fraction of the decay phase completed, pinned to [0, 1] so the
        # rate holds at min_lr after total_steps.
        progress = tf.clip_by_value((step - warmup_steps) / decay_denom, 0.0, 1.0)
        cosine_lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (
            1.0 + tf.cos(np.pi * progress)
        )
        return tf.where(step < warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        """Return the constructor arguments as a dict (for re-creating the schedule)."""
        return {
            "base_lr": self.base_lr,
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "min_lr": self.min_lr,
        }
def _tile_images(images, grid_size=None):
    """Arrange a batch of images into one square uint8 grid array.

    Args:
        images: array-like (NumPy array, eager tensor, list of arrays) of
            shape (N, H, W), (N, H, W, 1), (N, H, W, 3) or (N, H, W, 4).
            Values are multiplied by 255 and clipped to [0, 255], so the
            expected range is [0, 1].
        grid_size: tiles per row and per column. Defaults to
            ceil(sqrt(N)). Images beyond ``grid_size ** 2`` are dropped.

    Returns:
        uint8 array of shape (grid_size * H, grid_size * W, C). C is 4 for
        RGBA input and 3 otherwise; single-channel input is replicated
        across RGB. Unfilled cells stay black.

    Raises:
        ValueError: if the batch is empty or its shape is unsupported.
    """
    arr = np.asarray(images)
    if arr.ndim == 3:
        # (N, H, W) grayscale -> (N, H, W, 1) so it broadcasts like 1-channel.
        arr = arr[..., np.newaxis]
    if arr.ndim != 4 or arr.shape[-1] not in (1, 3, 4):
        raise ValueError(
            f"expected images of shape (N, H, W[, 1|3|4]), got {arr.shape}"
        )
    n, h, w, channels = arr.shape
    if n == 0:
        raise ValueError("cannot build an image grid from an empty batch")
    if grid_size is None:
        grid_size = int(np.ceil(np.sqrt(n)))
    if not np.issubdtype(arr.dtype, np.floating):
        # Scale integer/bool input in floating point: e.g. uint8 * 255 would
        # otherwise wrap around instead of saturating at 255.
        arr = arr.astype(np.float64)

    out_channels = 4 if channels == 4 else 3
    grid = np.zeros((grid_size * h, grid_size * w, out_channels), dtype=np.uint8)
    for i in range(min(n, grid_size * grid_size)):
        row, col = divmod(i, grid_size)
        tile = (arr[i] * 255).clip(0, 255).astype(np.uint8)
        # A (H, W, 1) tile broadcasts across the three RGB channels.
        grid[row * h : (row + 1) * h, col * w : (col + 1) * w] = tile
    return grid


def save_images_grid(images, filepath, grid_size=None):
    """Tile a batch of images into a square grid and save it as one image.

    Args:
        images: batch accepted by ``_tile_images`` (values in [0, 1]; shape
            (N, H, W) or (N, H, W, C) with C in {1, 3, 4}). Tensors exposing
            the NumPy array protocol, such as eager TF tensors, are accepted.
        filepath: destination path; the file format is chosen by Pillow
            (from the extension).
        grid_size: tiles per row/column; defaults to ceil(sqrt(N)).

    Returns:
        ``filepath``, unchanged.

    Raises:
        ValueError: if the batch is empty or has an unsupported shape.
    """
    grid = _tile_images(images, grid_size)
    # Local import: Pillow is only needed when an image is actually written.
    from PIL import Image as PILImage

    PILImage.fromarray(grid).save(filepath)
    return filepath
def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Args:
        model: object exposing ``trainable_variables``, an iterable of
            variables with fully defined ``shape`` (e.g. a Keras model).

    Returns:
        int: sum of the element counts of all trainable variables, or 0 when
        there are none.
    """
    # dtype=np.int64 stops scalar variables (shape ``()``) from contributing
    # the float 1.0 that np.prod returns for an empty shape. int() turns the
    # NumPy scalar into a plain Python int.
    return sum(
        int(np.prod(v.shape, dtype=np.int64)) for v in model.trainable_variables
    )
def get_model_summary(config):
    """Print a short summary of the model configuration and return it.

    The token grid is ``image_size // patch_size`` tiles per side, and the
    sequence length is the square of that.

    Args:
        config: nested dict where ``config["model"]["tokenizer"]`` provides
            ``image_size``, ``patch_size`` and ``num_tokens``, and
            ``config["model"]["transformer"]`` provides ``num_layers``,
            ``d_model`` and ``num_heads``.

    Returns:
        str: the summary text exactly as printed to stdout, with lines
        joined by ``"\\n"`` and no trailing newline.

    Raises:
        KeyError: if any of the keys listed above is missing.
    """
    tok = config["model"]["tokenizer"]
    trans = config["model"]["transformer"]
    grid_size = tok["image_size"] // tok["patch_size"]
    seq_len = grid_size * grid_size

    rule = "=" * 60
    lines = [
        rule,
        " AKASHA Model Configuration",
        rule,
        f" Image Size: {tok['image_size']}x{tok['image_size']}",
        f" Patch Size: {tok['patch_size']}x{tok['patch_size']}",
        f" Grid Size: {grid_size}x{grid_size}",
        f" Sequence Length: {seq_len} tokens",
        f" Codebook Size: {tok['num_tokens']}",
        f" Transformer: {trans['num_layers']}L / {trans['d_model']}D / {trans['num_heads']}H",
        rule,
    ]
    summary = "\n".join(lines)
    # A single print of the joined text produces the same stdout as printing
    # each line separately.
    print(summary)
    return summary