""" Training script for Text-to-YOLO-Weights Hypernetwork. Based on DnD (Drag-and-Drop LLMs) architecture with p-diff noise augmentation. Pipeline: 1. Load pre-generated dataset of (text_description, LoRA_adapter_vector) 2. Train hyper-convolutional decoder to predict adapter weights from text embeddings 3. Validate generated weights by measuring MSE and (optionally) running YOLO inference """ import os import json import random import argparse from typing import Dict, List, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from sentence_transformers import SentenceTransformer # --- Configuration --- class Config: text_encoder_model: str = "sentence-transformers/all-MiniLM-L6-v2" text_embed_dim: int = 384 decoder_hidden_dims: List[int] = [1024, 2048, 4096, 2048, 1024] num_tokens: int = 64 # sequence length for 1D conv lora_r: int = 16 batch_size: int = 4 lr: float = 1e-4 num_epochs: int = 100 weight_noise_scale: float = 0.001 latent_noise_scale: float = 0.1 dataset_path: str = "./text_to_yolo_dataset/text_to_yolo_dataset.json" output_dir: str = "./text_to_yolo_output" # Trackio trackio_project: str = "text-to-yolo-weights" trackio_space_id: str = "mabbam/text-to-yolo-trackio" # --- Hyper-Convolutional Decoder --- class HyperConvBlock(nn.Module): def __init__(self, in_dim: int, out_dim: int, kernel_size: int = 3): super().__init__() self.conv1 = nn.Conv1d(in_dim, out_dim, kernel_size, padding=kernel_size // 2) self.conv2 = nn.Conv1d(out_dim, out_dim, kernel_size, padding=kernel_size // 2) self.conv3 = nn.Conv1d(out_dim, out_dim, kernel_size, padding=kernel_size // 2) self.norm1 = nn.GroupNorm(8, out_dim) self.norm2 = nn.GroupNorm(8, out_dim) self.norm3 = nn.GroupNorm(8, out_dim) self.skip = nn.Conv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() def forward(self, x): residual = self.skip(x) x = F.gelu(self.norm1(self.conv1(x))) x = F.gelu(self.norm2(self.conv2(x))) x = self.norm3(self.conv3(x)) return x + residual class HyperWeightDecoder(nn.Module): def __init__(self, config: Config, layer_shapes: Dict[str, Tuple[int, int]]): super().__init__() self.config = config self.layer_shapes = layer_shapes self.layer_names = list(layer_shapes.keys()) # Total parameters to generate: for each layer, A (in_f, r) + B (out_f, r) self.total_params = 0 self.param_info = {} for name, (out_f, in_f) in layer_shapes.items(): a_size = in_f * config.lora_r b_size = out_f * config.lora_r self.param_info[name] = {"offset": self.total_params, "a_size": a_size, "in_f": in_f, "out_f": out_f} self.total_params += a_size + b_size # Text embedding projection to conv sequence self.text_proj = nn.Linear(config.text_embed_dim, config.num_tokens * config.decoder_hidden_dims[0]) # Cascaded hyper-convolution blocks dims = [config.decoder_hidden_dims[0]] + config.decoder_hidden_dims self.blocks = nn.ModuleList([ HyperConvBlock(dims[i], dims[i+1]) for i in range(len(dims)-1) ]) # Final head self.head = nn.Sequential( nn.Linear(dims[-1] * config.num_tokens, 8192), nn.GELU(), nn.LayerNorm(8192), nn.Linear(8192, self.total_params), ) def forward(self, text_emb: torch.Tensor, add_noise: bool = True): B = text_emb.size(0) x = self.text_proj(text_emb).view(B, self.config.decoder_hidden_dims[0], self.config.num_tokens) for block in self.blocks: x = block(x) x = x.view(B, -1) weights = self.head(x) if self.training and add_noise: weights = weights + torch.randn_like(weights) * self.config.weight_noise_scale # Reshape into per-layer LoRA A/B adapters = {} for name in self.layer_names: info = self.param_info[name] r = self.config.lora_r w = weights[:, info["offset"]:info["offset"] + info["a_size"] + info["out_f"] * r] a = w[:, :info["a_size"]].view(B, info["in_f"], r) b = w[:, info["a_size"]:].view(B, info["out_f"], r) adapters[name] = (a, b) return adapters, weights # --- Dataset --- class TextToYoloDataset(Dataset): def __init__(self, dataset_path: str): with open(dataset_path, "r") as f: self.data = json.load(f) print(f"Loaded dataset with {len(self.data)} samples") def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] prompt = sample["description"] weights = torch.tensor(sample["weight_vector"], dtype=torch.float32) return prompt, weights # --- Training --- def train(config: Config): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Load dataset dataset = TextToYoloDataset(config.dataset_path) dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True) # Load layer shapes from metadata shapes_path = os.path.join(os.path.dirname(config.dataset_path), "lora_shapes.json") with open(shapes_path, "r") as f: layer_shapes = json.load(f) # Convert to tuples layer_shapes = {k: tuple(v) for k, v in layer_shapes.items()} # Initialize models print("Loading text encoder...") text_encoder = SentenceTransformer(config.text_encoder_model).to(device) for p in text_encoder.parameters(): p.requires_grad = False print(f"Initializing decoder for {len(layer_shapes)} layers, {sum(v[0]*v[1] for v in layer_shapes.values())} base params...") decoder = HyperWeightDecoder(config, layer_shapes).to(device) print(f"Decoder trainable params: {sum(p.numel() for p in decoder.parameters()):,}") print(f"Target weight vector size: {decoder.total_params:,}") optimizer = torch.optim.AdamW(decoder.parameters(), lr=config.lr, weight_decay=0.01) scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2) # Optional trackio try: import trackio trackio.init(project=config.trackio_project, space_id=config.trackio_space_id) print("Trackio initialized") use_trackio = True except ImportError: use_trackio = False print("Trackio not available") os.makedirs(config.output_dir, exist_ok=True) best_loss = float("inf") for epoch in range(config.num_epochs): decoder.train() total_loss = 0.0 num_batches = 0 for prompts, targets in dataloader: targets = targets.to(device) with torch.no_grad(): text_emb = text_encoder.encode(prompts, convert_to_tensor=True, show_progress_bar=False) text_emb = text_emb.to(device) _, pred_weights = decoder(text_emb) # Latent noise augmentation (p-diff style) if config.latent_noise_scale > 0: pred_weights = pred_weights + torch.randn_like(pred_weights) * config.latent_noise_scale loss = F.mse_loss(pred_weights, targets) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0) optimizer.step() total_loss += loss.item() num_batches += 1 avg_loss = total_loss / max(num_batches, 1) scheduler.step() print(f"Epoch {epoch+1}/{config.num_epochs} | Loss: {avg_loss:.6f} | LR: {scheduler.get_last_lr()[0]:.2e}") if use_trackio: trackio.log({"loss": avg_loss, "epoch": epoch, "lr": scheduler.get_last_lr()[0]}) if avg_loss < best_loss: best_loss = avg_loss torch.save({ "decoder": decoder.state_dict(), "config": vars(config), "layer_shapes": layer_shapes, "epoch": epoch, "loss": avg_loss, }, os.path.join(config.output_dir, "best_decoder.pt")) print(f"Training complete. Best loss: {best_loss:.6f}") print(f"Saved to {config.output_dir}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--dataset_path", default="./text_to_yolo_dataset/text_to_yolo_dataset.json") parser.add_argument("--output_dir", default="./text_to_yolo_output") parser.add_argument("--batch_size", type=int, default=4) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--epochs", type=int, default=100) args = parser.parse_args() config = Config() config.dataset_path = args.dataset_path config.output_dir = args.output_dir config.batch_size = args.batch_size config.lr = args.lr config.num_epochs = args.epochs train(config)