"""
Training script for the Text-to-YOLO-Weights hypernetwork.
Based on the DnD (Drag-and-Drop LLMs) architecture with p-diff-style noise augmentation.

Pipeline:
1. Load a pre-generated dataset of (text_description, LoRA_adapter_vector) pairs.
2. Train a hyper-convolutional decoder to predict adapter weights from text embeddings.
3. Track the per-epoch training MSE between predicted and target weight vectors and
   save the decoder checkpoint with the lowest loss. (This script does not run a
   held-out validation pass or YOLO inference on the generated weights.)
"""
import argparse
import json
import os
import random
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, Dataset
|
|
|
|
@dataclass(eq=False)
class Config:
    """Hyper-parameters and paths for training the text-to-LoRA decoder.

    Declared as a dataclass so that every instance owns its own
    ``decoder_hidden_dims`` list, and ``vars(config)`` contains the default
    values as well as any overrides (a plain class only exposes attributes that
    were assigned on the instance). ``eq=False`` keeps identity-based equality
    and hashing, matching the previous plain-class behaviour.
    """

    # Frozen text encoder and the width of the embeddings it produces.
    text_encoder_model: str = "sentence-transformers/all-MiniLM-L6-v2"
    text_embed_dim: int = 384  # must equal the text encoder's embedding width

    # Decoder architecture.
    decoder_hidden_dims: List[int] = field(
        default_factory=lambda: [1024, 2048, 4096, 2048, 1024]
    )
    num_tokens: int = 64
    # LoRA rank; the decoder's output size scales with it, so it must agree
    # with the length of the dataset's weight vectors.
    lora_r: int = 16

    # Optimisation and noise augmentation.
    batch_size: int = 4
    lr: float = 1e-4
    num_epochs: int = 100
    weight_noise_scale: float = 0.001
    latent_noise_scale: float = 0.1

    # Input dataset and output directory.
    dataset_path: str = "./text_to_yolo_dataset/text_to_yolo_dataset.json"
    output_dir: str = "./text_to_yolo_output"

    # Experiment tracking (trackio).
    trackio_project: str = "text-to-yolo-weights"
    trackio_space_id: str = "mabbam/text-to-yolo-trackio"
|
|
|
|
class HyperConvBlock(nn.Module):
    """Residual block of three same-padded Conv1d + GroupNorm layers.

    GELU follows the first two normalised convolutions. The third output is
    added to the input, which first goes through a 1x1 projection when the
    widths differ.

    Args:
        in_dim: Input channels; input is laid out as (batch, channels, length).
        out_dim: Output channels.
        kernel_size: Odd kernel width. With ``padding = kernel_size // 2`` an
            odd kernel keeps the sequence length unchanged, which the residual
            add requires.
        num_groups: GroupNorm group count; ``out_dim`` must be divisible by it.

    Raises:
        ValueError: If ``kernel_size`` is even. An even kernel would grow the
            length by one per convolution, and ``forward`` would then fail on
            the residual add.
    """

    def __init__(self, in_dim: int, out_dim: int, kernel_size: int = 3, num_groups: int = 8):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve sequence length, got {kernel_size}"
            )
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(in_dim, out_dim, kernel_size, padding=padding)
        self.conv2 = nn.Conv1d(out_dim, out_dim, kernel_size, padding=padding)
        self.conv3 = nn.Conv1d(out_dim, out_dim, kernel_size, padding=padding)
        self.norm1 = nn.GroupNorm(num_groups, out_dim)
        self.norm2 = nn.GroupNorm(num_groups, out_dim)
        self.norm3 = nn.GroupNorm(num_groups, out_dim)
        # Project the residual only when the channel count changes.
        self.skip = nn.Conv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_dim, length)."""
        residual = self.skip(x)
        x = F.gelu(self.norm1(self.conv1(x)))
        x = F.gelu(self.norm2(self.conv2(x)))
        x = self.norm3(self.conv3(x))
        return x + residual
|
|
|
|
class HyperWeightDecoder(nn.Module):
    """Map a text embedding to a flat vector of LoRA adapter weights.

    The flat output is laid out layer by layer, in ``layer_shapes`` order. For
    each layer it holds ``in_f * lora_r`` values for A, followed by
    ``out_f * lora_r`` values for B.

    Args:
        config: Reads ``text_embed_dim``, ``num_tokens``, ``decoder_hidden_dims``,
            ``lora_r`` and ``weight_noise_scale``. It also reads the optional
            attribute ``head_hidden_dim``, the width of the MLP head, which
            defaults to 8192 when absent.
        layer_shapes: Mapping of layer name to ``(out_features, in_features)``.
    """

    def __init__(self, config: Config, layer_shapes: Dict[str, Tuple[int, int]]):
        super().__init__()
        self.config = config
        self.layer_shapes = layer_shapes
        self.layer_names = list(layer_shapes.keys())

        # Compute where each layer's A/B factors live inside the flat output.
        r = config.lora_r
        self.total_params = 0
        self.param_info = {}
        for name, (out_f, in_f) in layer_shapes.items():
            a_size = in_f * r
            b_size = out_f * r
            self.param_info[name] = {
                "offset": self.total_params,
                "a_size": a_size,
                "b_size": b_size,
                "in_f": in_f,
                "out_f": out_f,
            }
            self.total_params += a_size + b_size

        first_width = config.decoder_hidden_dims[0]
        # Project the embedding into num_tokens channel-first tokens for Conv1d.
        self.text_proj = nn.Linear(config.text_embed_dim, config.num_tokens * first_width)

        # The first block keeps the initial width; later blocks step through the list.
        dims = [first_width] + list(config.decoder_hidden_dims)
        self.blocks = nn.ModuleList(
            [HyperConvBlock(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

        head_hidden = getattr(config, "head_hidden_dim", 8192)
        # NOTE(review): the last Linear has head_hidden * total_params weights, so
        # the head's memory grows linearly with the total adapter size.
        self.head = nn.Sequential(
            nn.Linear(dims[-1] * config.num_tokens, head_hidden),
            nn.GELU(),
            nn.LayerNorm(head_hidden),
            nn.Linear(head_hidden, self.total_params),
        )

    def forward(self, text_emb: torch.Tensor, add_noise: bool = True):
        """Predict adapters from ``text_emb`` of shape (batch, text_embed_dim).

        Args:
            text_emb: Batch of text embeddings.
            add_noise: In training mode, add Gaussian noise with standard
                deviation ``config.weight_noise_scale`` to the flat output.

        Returns:
            ``(adapters, weights)``. ``adapters[name]`` is ``(a, b)``, where
            ``a`` has shape (batch, in_f, r) and ``b`` has shape
            (batch, out_f, r). ``weights`` is the flat tensor of shape
            (batch, total_params).
        """
        batch = text_emb.size(0)
        x = self.text_proj(text_emb).view(
            batch, self.config.decoder_hidden_dims[0], self.config.num_tokens
        )
        for block in self.blocks:
            x = block(x)

        weights = self.head(x.flatten(1))

        if self.training and add_noise:
            weights = weights + torch.randn_like(weights) * self.config.weight_noise_scale

        r = self.config.lora_r
        adapters = {}
        for name in self.layer_names:
            info = self.param_info[name]
            a_start = info["offset"]
            b_start = a_start + info["a_size"]
            b_end = b_start + info["b_size"]
            a = weights[:, a_start:b_start].view(batch, info["in_f"], r)
            b = weights[:, b_start:b_end].view(batch, info["out_f"], r)
            adapters[name] = (a, b)

        return adapters, weights
|
|
|
|
class TextToYoloDataset(Dataset):
    """(prompt, weight vector) pairs loaded from a JSON list of records.

    Each record must provide a ``"description"`` string and a
    ``"weight_vector"`` list of numbers. Items are returned as
    ``(str, float32 tensor)``.
    """

    def __init__(self, dataset_path: str):
        # Explicit UTF-8: descriptions may contain non-ASCII text, and the
        # platform default encoding (e.g. cp1252 on Windows) would mis-decode it.
        with open(dataset_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        print(f"Loaded dataset with {len(self.data)} samples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        prompt = sample["description"]
        # Converted on each access; only the parsed JSON lists are kept in memory.
        weights = torch.tensor(sample["weight_vector"], dtype=torch.float32)
        return prompt, weights
|
|
|
|
def _load_layer_shapes(dataset_path: str) -> Dict[str, Tuple[int, ...]]:
    """Read ``lora_shapes.json`` from the directory that contains *dataset_path*.

    Returns a mapping of layer name to shape tuple (JSON lists are converted to
    tuples). ``HyperWeightDecoder`` unpacks each shape as ``(out_f, in_f)``.
    """
    shapes_path = os.path.join(os.path.dirname(dataset_path), "lora_shapes.json")
    with open(shapes_path, "r", encoding="utf-8") as f:
        raw_shapes = json.load(f)
    return {name: tuple(shape) for name, shape in raw_shapes.items()}


def _config_snapshot(config) -> dict:
    """Collect every public, non-callable attribute of *config* into a dict.

    Unlike ``vars(config)``, this also picks up values that exist only as
    class-level defaults, so the checkpoint records the full set of
    hyper-parameters needed to rebuild the decoder.
    """
    snapshot = {}
    for name in dir(config):
        if name.startswith("_"):
            continue
        value = getattr(config, name)
        if not callable(value):
            snapshot[name] = value
    return snapshot


def _init_trackio(config):
    """Start a trackio run if the package can be imported.

    Returns the ``trackio`` module when a run was started, otherwise ``None``.
    """
    try:
        import trackio
    except ImportError:
        print("Trackio not available")
        return None
    trackio.init(project=config.trackio_project, space_id=config.trackio_space_id)
    print("Trackio initialized")
    return trackio


def train(config: Config):
    """Train a HyperWeightDecoder and keep the checkpoint with the lowest epoch loss.

    Reads the dataset at ``config.dataset_path`` and ``lora_shapes.json`` from
    the same directory. Writes ``best_decoder.pt`` into ``config.output_dir``
    whenever the mean epoch loss improves.

    Raises:
        ValueError: If the dataset is empty, or if its weight vectors do not
            have the length the decoder predicts.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    dataset = TextToYoloDataset(config.dataset_path)
    if len(dataset) == 0:
        raise ValueError(f"Dataset at {config.dataset_path} is empty; nothing to train on")
    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    layer_shapes = _load_layer_shapes(config.dataset_path)

    print("Loading text encoder...")
    text_encoder = SentenceTransformer(config.text_encoder_model).to(device)
    for p in text_encoder.parameters():
        p.requires_grad = False

    print(f"Initializing decoder for {len(layer_shapes)} layers, {sum(v[0]*v[1] for v in layer_shapes.values())} base params...")
    decoder = HyperWeightDecoder(config, layer_shapes).to(device)
    print(f"Decoder trainable params: {sum(p.numel() for p in decoder.parameters()):,}")
    print(f"Target weight vector size: {decoder.total_params:,}")

    # Fail fast with a clear message instead of a shape error inside mse_loss
    # on the first batch.
    target_size = dataset[0][1].numel()
    if target_size != decoder.total_params:
        raise ValueError(
            f"Dataset weight vectors have {target_size} values but the decoder predicts "
            f"{decoder.total_params}; check lora_shapes.json and config.lora_r"
        )

    optimizer = torch.optim.AdamW(decoder.parameters(), lr=config.lr, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    tracker = _init_trackio(config)

    os.makedirs(config.output_dir, exist_ok=True)
    checkpoint_path = os.path.join(config.output_dir, "best_decoder.pt")
    saved_config = _config_snapshot(config)

    best_loss = float("inf")
    for epoch in range(config.num_epochs):
        decoder.train()
        total_loss = 0.0
        num_batches = 0

        for prompts, targets in dataloader:
            targets = targets.to(device)

            # The encoder is frozen; its embeddings are constant inputs to the decoder.
            with torch.no_grad():
                text_emb = text_encoder.encode(prompts, convert_to_tensor=True, show_progress_bar=False)
            text_emb = text_emb.to(device)

            _, pred_weights = decoder(text_emb)

            # NOTE(review): despite its name, this noise is added to the predicted
            # weights right before the MSE. It contributes zero-mean gradient noise
            # and raises the expected reported loss by latent_noise_scale**2; it
            # does not perturb any latent representation.
            if config.latent_noise_scale > 0:
                pred_weights = pred_weights + torch.randn_like(pred_weights) * config.latent_noise_scale

            loss = F.mse_loss(pred_weights, targets)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch+1}/{config.num_epochs} | Loss: {avg_loss:.6f} | LR: {current_lr:.2e}")

        if tracker is not None:
            tracker.log({"loss": avg_loss, "epoch": epoch, "lr": current_lr})

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                "decoder": decoder.state_dict(),
                "config": saved_config,
                "layer_shapes": layer_shapes,
                "epoch": epoch,
                "loss": avg_loss,
            }, checkpoint_path)

    if tracker is not None:
        tracker.finish()

    print(f"Training complete. Best loss: {best_loss:.6f}")
    print(f"Saved to {config.output_dir}")
|
|
|
|
if __name__ == "__main__":
    # Config() is the single source of truth for defaults, so the CLI cannot
    # drift from the values used when train() is called programmatically.
    config = Config()

    parser = argparse.ArgumentParser(
        description="Train the Text-to-YOLO-Weights hypernetwork decoder."
    )
    parser.add_argument(
        "--dataset_path",
        default=config.dataset_path,
        help="Dataset JSON; lora_shapes.json must be in the same directory.",
    )
    parser.add_argument(
        "--output_dir",
        default=config.output_dir,
        help="Directory where best_decoder.pt is written.",
    )
    parser.add_argument("--batch_size", type=int, default=config.batch_size)
    parser.add_argument("--lr", type=float, default=config.lr)
    parser.add_argument(
        "--epochs", type=int, default=config.num_epochs, help="Number of training epochs."
    )
    args = parser.parse_args()

    config.dataset_path = args.dataset_path
    config.output_dir = args.output_dir
    config.batch_size = args.batch_size
    config.lr = args.lr
    config.num_epochs = args.epochs

    train(config)
|
|