|
|
--- |
|
|
license: apache-2.0 |
|
|
--- |
|
|
|
|
|
|
|
|
# Single-Stream DiT with Global Fourier Filters (Proof-of-Concept) |
|
|
|
|
|
This repository contains the codebase for a Single-Stream Diffusion Transformer (DiT) Proof-of-Concept, heavily inspired by modern architectures like **FLUX.1**, **Z-Image**, and **Lumina Image 2**. |
|
|
|
|
|
The primary objective was to demonstrate the feasibility and training stability of coupling the high-fidelity **FLUX.1-VAE** with the powerful **T5Gemma2** text encoder for image generation on consumer-grade hardware (NVIDIA RTX 5060 Ti 16GB). |
|
|
|
|
|
**Note:** The entire project codebase can be found on the [GitHub](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2) page! |
|
|
|
|
|
## Project Overview |
|
|
|
|
|
### How it started |
|
|
|
|
|
 |
|
|
|
|
|
### Verification and Result Comparison |
|
|
|
|
|
| Cached Latent Verification | Final Generated Sample (Euler 50 steps and CFG 3.0) | |
|
|
| :---: | :---: | |
|
|
|  |  | |
|
|
|
|
|
## Core Models and Architecture |
|
|
|
|
|
| Component | Model ID / Function | Purpose | |
|
|
| :--- | :--- | :--- | |
|
|
| **Generator** | `SingleStreamDiTV2` | Custom Single-Stream DiT featuring Visual Fusion blocks, Context Refiners, and Fourier Filters. DiT Parameters: _768 Hidden Size, 12 Heads, 16 Depth, 2 Refiner Depth, 128 Text Token Length, 2 Patch Size._ |
|
|
| **Text Encoder** | `google/t5gemma-2-1b-1b` | Generates rich, 1152-dimensional text embeddings for high-quality semantic guidance. | |
|
|
| **VAE** | `diffusers/FLUX.1-vae` | A 16-channel VAE with an 8x downsample factor, providing superior reconstruction for complex textures. | |
|
|
| **Training Method** | Flow Matching (V-Prediction) | Optimized with a Velocity-based objective and an optional Self-Evaluation (Self-E) consistency loss. | |
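The Flow Matching (V-Prediction) objective in the table can be sketched as follows: interpolate clean latents toward noise along a linear path and regress the model output onto the true velocity with MSE. This is a minimal illustration under stated assumptions, not the repository's exact `train.py`; the `model(x_t, t, text_emb)` call signature and the timestep-shift formula are stand-ins.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, text_emb, shift=1.0):
    """Velocity-objective flow matching loss (sketch).

    x0: clean VAE latents, shape (B, C, H, W).
    shift=1.0 reduces the timestep shift to uniform sampling,
    matching the 'Shift Value 1.0 (Uniform)' setting above.
    """
    b = x0.shape[0]
    t = torch.rand(b, device=x0.device)           # uniform timesteps in [0, 1)
    t = shift * t / (1.0 + (shift - 1.0) * t)     # shift transform (identity at 1.0)
    noise = torch.randn_like(x0)
    t_ = t.view(b, 1, 1, 1)
    x_t = (1.0 - t_) * x0 + t_ * noise            # linear interpolation path
    v_target = noise - x0                         # true velocity along the path
    v_pred = model(x_t, t, text_emb)              # hypothetical model signature
    return F.mse_loss(v_pred, v_target)
```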
|
|
|
|
|
## New in V3 |
|
|
- **Refinement Stages:** Separate noise and context refiner blocks to "prep" tokens before the joint fusion phase. |
|
|
- **Fourier Filters:** Frequency-domain processing layers to improve global structural coherence. |
|
|
- **Local Spatial Bias:** Conv2D-based depthwise biases to reinforce local texture within the transformer. |
|
|
- **Rotary Embeddings (RoPE):** Dynamic 2D-RoPE grid support for area-preserving bucketing. |
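The Fourier Filter idea above can be sketched as a frequency-domain token-mixing layer: FFT over the 2D token grid, elementwise multiplication by a learnable complex filter, then inverse FFT. This is a hedged sketch of the general technique; the actual block in `model.py` (names, shapes, gating) may differ.

```python
import torch
import torch.nn as nn

class GlobalFourierFilter(nn.Module):
    """Global frequency-domain mixing over a 2D token grid (sketch)."""

    def __init__(self, height: int, width: int, dim: int):
        super().__init__()
        # rfft2 keeps width // 2 + 1 frequency bins along the last spatial axis;
        # the trailing 2 stores real/imaginary parts of the learnable filter.
        self.filter = nn.Parameter(torch.randn(height, width // 2 + 1, dim, 2) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (B, H, W, C)
        h, w = x.shape[1], x.shape[2]
        xf = torch.fft.rfft2(x.float(), dim=(1, 2), norm="ortho")
        xf = xf * torch.view_as_complex(self.filter)      # global spectral filtering
        out = torch.fft.irfft2(xf, s=(h, w), dim=(1, 2), norm="ortho")
        return out.to(x.dtype)
```

Because every frequency bin mixes information from all spatial positions, a single layer has a global receptive field, which is why it helps structural coherence.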
|
|
|
|
|
## Training Progression |
|
|
|
|
|
| Early Epoch (Epoch 25) | Final Epoch (Epoch 1200) | Full Progression | |
|
|
| :---: | :---: | :---: | |
|
|
|  |  |  | |
|
|
|
|
|
## Data Curation and Preprocessing |
|
|
|
|
|
The model was trained and evaluated on a curated dataset of **200 images** (10 categories of flowers) as a first step before scaling to larger datasets.
|
|
|
|
|
| Component | Tool / Method | Purpose / Detail | |
|
|
| :--- | :--- | :--- | |
|
|
| **Pre/Post-processing** | **[Dataset Helpers](https://github.com/Particle1904/DatasetHelpers)** | Used to resize images (using **[DPID](https://github.com/Mishini/dpid)** - Detail-Preserving Image Downscaling) and edit the Qwen3-VL captions. | |
|
|
| **Captioning** | **Qwen3-VL-4B-Instruct** | Captions include precise botanical details: texture (waxy, serrated), plant anatomy (stamen, pistil), and camera lighting. | |
|
|
| **Data Encoding** | `preprocess.py` | Encodes images via FLUX-VAE and text via T5Gemma2, applying aspect-ratio bucketing. | |
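The aspect-ratio bucketing applied during encoding can be sketched as an area-preserving snap: pick the bucket whose aspect ratio is closest to the image while keeping a roughly constant pixel area and side lengths divisible by 16 (the 8x VAE downsample times the patch size of 2). The function name and target area are illustrative assumptions, not the exact logic of `preprocess.py`.

```python
import math

def nearest_bucket(width: int, height: int, target_area: int = 512 * 512, step: int = 16):
    """Area-preserving bucket for an image of the given size (sketch).

    Returns (bucket_w, bucket_h), both multiples of `step`, with
    bucket_w * bucket_h close to target_area and the aspect ratio
    close to width / height.
    """
    aspect = width / height
    bucket_h = max(round(math.sqrt(target_area / aspect) / step) * step, step)
    bucket_w = max(round((target_area / bucket_h) / step) * step, step)
    return bucket_w, bucket_h
```

For example, a 1536x1024 image (3:2) would land in a 624x416 bucket: nearly the same aspect ratio and pixel budget, but patch-grid friendly.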
|
|
|
|
|
<details> |
|
|
<summary><h2><b>Qwen3-VL-4B-Instruct System Instruction (Captioning Prompt)</b></h2></summary> |
|
|
<i>You are a specialized botanical image analysis system operating within a research environment. Your task is to generate concise, scientifically accurate, and visually descriptive captions for flower images. All output must be strictly factual, objective, and devoid of non-visual assumptions. |
|
|
|
|
|
Your task is to generate captions for images based on the visual content and a provided reference flower category name. Captions must be precise, comprehensive, and meticulously aligned with the visual details of the plant structure, color gradients, and lighting. |
|
|
|
|
|
Caption Style: Generate concise captions that are no more than 50 words. Focus on combining descriptors into brief phrases (separated by commas). Follow this structure: "A \<view type\> of a \<flower name\>, having \<petal details\>, the center is \<center details\>, the background is \<background description\>, \<lighting/style information\>" |
|
|
|
|
|
Hierarchical Description: Begin with the flower name and its primary state (blooming, budding, wilting). Move to the petals (color, shape, texture), then the reproductive parts (stamen, pistil, pollen), then the stem/leaves, and finally the environment. |
|
|
|
|
|
Factual Accuracy & Label Verification: The provided "Input Flower Name" is a reference tag. You must visually verify this tag against the image content. |
|
|
* Match: If the visual features match the tag, use the provided name. |
|
|
* Correction: If the visual characteristics definitively belong to a different species (e.g., input says "Sunflower" but the image clearly shows a "Rose"), you must override the input and use the visually correct botanical name in the caption. |
|
|
* Ambiguity: If the species is unclear, describe the visual features precisely without forcing a specific name. |
|
|
|
|
|
Precise Botanical Terminology: Use correct terminology for plant anatomy. |
|
|
* Petals: Describe edges (serrated, smooth, ruffled), texture (velvety, waxy, delicate), and arrangement (overlapping, sparse, symmetrical). |
|
|
* Center: Use terms like "stamen", "pistil", "anthers", "pollen", "cone", or "disk" when visible. |
|
|
* Leaves/Stem: Describe shape (lance-shaped, oval), arrangement, and surface (glossy, hairy, thorny). |
|
|
|
|
|
Color and Texture: Be specific about colors. Do not just say "pink"; use "pale pink fading to white at the edges", "vibrant magenta", or "speckled purple". Describe patterns like "veining", "spots", "stripes", or "gradients". |
|
|
|
|
|
Condition and State: Describe the physical state of the flower. Examples: "fully in bloom", "closed bud", "drooping petals", "withered edges", or "covered in dew droplets". |
|
|
|
|
|
Environmental Description: Describe the setting strictly as seen. Examples: "green leafy background", "blurry garden setting", "studio black background", "natural sunlight", "dirt ground". |
|
|
|
|
|
Camera Perspective and Style: Crucial for DiT training. Specify: |
|
|
* Shot Type: "Extreme close-up", "macro shot", "eye-level shot", "top-down view". |
|
|
* Focus: "Shallow depth of field", "bokeh background", "sharp focus", "soft focus". |
|
|
* Lighting: "Natural lighting", "harsh shadows", "dappled sunlight", "studio lighting". |
|
|
|
|
|
Output Format: Output a single string containing the caption, without double quotes, using commas to separate phrases.</i> |
|
|
</details> |
|
|
|
|
|
## Training History and Configuration |
|
|
|
|
|
Training utilizes **8-bit AdamW** and a **Cosine Schedule with 5% Warmup**, running for 1200 epochs (stopped early) with an **MSE** objective.
|
|
|
|
|
| Configuration | Value | Purpose | |
|
|
| :--- | :--- | :--- | |
|
|
| **Loss / LR** | **`MSE` at `2e-4`** | Trained with an MSE objective at a learning rate of 2e-4. |
|
|
| **Batch Size** | **`16`** | Gradient Checkpointing enabled, with gradient accumulation steps set to 2. |
|
|
| **Shift Value** | **`1.0` (Uniform)** | Ensures balanced training across all noise levels, critical for learning geometry on small datasets. |
|
|
| **Latent Norm** | **`0.0 Mean / 1.0 Std`** | Hardcoded identity normalization to preserve the relative channel relationships of the FLUX VAE. **Note:** Using a Mean and Std calculated from the dataset resulted in poor reconstruction with artifacts. | |
|
|
| **EMA Decay** | **`0.999`** | Maintains a moving average of weights for smoother, higher-quality inference. | |
|
|
| **Self-Evaluation (Self-E)** | **`Disabled`** | Optional teacher-student distillation. (**Note:** Not used in this PoC to maintain baseline architectural clarity). |
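The EMA row above follows the standard update rule: keep a shadow copy of the weights and blend in the live weights each optimizer step. The sketch below operates on plain floats for clarity; `train.py` applies the same rule to torch tensors.

```python
class EMA:
    """Exponential moving average of model parameters (sketch).

    With decay 0.999 (as in the table above), the shadow weights
    track a smoothed trajectory of training, which is what the
    higher-quality inference checkpoints are sampled from.
    """

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = list(params)  # initialize shadow from current weights

    def update(self, params):
        d = self.decay
        # shadow <- d * shadow + (1 - d) * current, elementwise
        self.shadow = [d * s + (1.0 - d) * p for s, p in zip(self.shadow, params)]
```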
|
|
|
|
|
### Loss & Fourier Gate Progression |
|
|
|
|
|
| Loss Graph | Fourier Gate | |
|
|
| :---: | :---: | |
|
|
|  |  | |
|
|
|
|
|
**Training Time Estimate:** |
|
|
* **GPU Time:** Approximately **6 hours and 21 minutes** of total GPU compute time for 1200 epochs (RTX 5060 Ti 16GB). |
|
|
* **Project Time (Human):** 13 days of R&D, including hyperparameter tuning. |
|
|
|
|
|
## Reproducibility |
|
|
|
|
|
This repository is designed to be fully reproducible. The following data is included in the respective directories: |
|
|
* **Raw Dataset:** The original `.png` images and the **Qwen3-VL-4B-Instruct** generated and reviewed `.txt` captions. |
|
|
* **Cached Dataset:** The processed, tokenized, and VAE-encoded latents (`.pt` files). |
|
|
|
|
|
## Repository File Breakdown |
|
|
|
|
|
### Training & Core Scripts |
|
|
|
|
|
| File | Purpose | Notes | |
|
|
| :--- | :--- | :--- | |
|
|
| **`train.py`** | Main training script. Supports EMA, Self-E, and Gradient Accumulation. | Includes automatic model compilation on Linux. | |
|
|
| **`model.py`** | Defines `SingleStreamDiTV2` with Visual Fusion, Fourier Filters, and SwiGLU. | The core architecture definition. | |
|
|
| **`config.py`** | Central configuration for paths, model dims, and hyperparameters. | All model settings are controlled here. | |
|
|
| **`sanity_check.py`** | A utility to ensure the model can overfit to a single cached latent file. | Used for debugging architecture changes. | |
|
|
|
|
|
### Utility & Preprocessing |
|
|
|
|
|
| File | Purpose | Notes | |
|
|
| :--- | :--- | :--- | |
|
|
| **`preprocess.py`** | Prepares raw image/text data into cached `.pt` files using VAE and T5. | Run this before starting training. | |
|
|
| **`calculate_cache_statistics.py`** | Analyzes cached latents to find Mean/Std for normalization settings. | **Note:** Use results with caution; defaults of 0.0/1.0 are often better. | |
|
|
| **`debug_vae_pipeline.py`** | Tests the VAE reconstruction pipeline in float32 to isolate VAE issues. | Useful for troubleshooting color shifts. | |
|
|
| **`check_cache.py`** | Decodes a single cached latent back to an image to verify preprocessing. | Fast integrity check. | |
|
|
| **`generate_graph.py`** | Generates the loss curve visualization from the training CSV logs. | Creates `loss_curve.png`. | |
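The core computation of `calculate_cache_statistics.py` can be sketched as a per-channel mean/std over all cached latents. The exact `.pt` layout in the cache is an assumption here; only the reduction itself is shown. (As the table warns, the resulting statistics degraded reconstruction in this project, so the identity defaults of 0.0/1.0 were kept.)

```python
import torch

def latent_channel_stats(latents):
    """Per-channel mean and std over a list of latent tensors (sketch).

    latents: list of tensors shaped (C, H, W), e.g. 16-channel FLUX
    VAE latents loaded from the cache. Spatial dims may differ per
    item because of aspect-ratio bucketing, so each is flattened first.
    """
    flat = torch.cat([z.flatten(1) for z in latents], dim=1)  # (C, total_pixels)
    return flat.mean(dim=1), flat.std(dim=1)
```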
|
|
|
|
|
### Inference & Data |
|
|
|
|
|
| File | Purpose | Notes | |
|
|
| :--- | :--- | :--- | |
|
|
| **`inferenceNotebook.ipynb`** | Primary inference tool. Supports text-to-image with Euler/RK4. | Best for interactive testing. | |
|
|
| **`samplers.py`** | Numerical integration steps for Euler and Runge-Kutta 4 (RK4). | Logic for the flow matching inference. | |
|
|
| **`latents.py`** | Scaling and normalization logic for VAE latents. | Shared across preprocess, train, and inference. | |
|
|
| **`dataset.py`** | Bucket-batching and RAM-caching dataset implementation. | Handles the training data pipeline. | |
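The Euler path in `samplers.py` amounts to integrating the learned velocity field from pure noise at t=1 down to data at t=0, which is the "Euler 50 steps and CFG 3.0" setting used for the sample at the top of this card. The sketch below assumes a hypothetical `model(x, t, cond)` signature and shows classifier-free guidance applied to the velocity; it is not the repository's exact implementation.

```python
import torch

@torch.no_grad()
def euler_sample(model, text_emb, shape, steps=50, cfg_scale=3.0, null_emb=None):
    """Euler integration of a flow-matching velocity field (sketch)."""
    x = torch.randn(shape)                         # start from pure noise at t=1
    ts = torch.linspace(1.0, 0.0, steps + 1)       # schedule from noise to data
    for i in range(steps):
        t = ts[i].expand(shape[0])
        v = model(x, t, text_emb)
        if null_emb is not None:
            v_null = model(x, t, null_emb)
            v = v_null + cfg_scale * (v - v_null)  # CFG on the predicted velocity
        x = x + (ts[i + 1] - ts[i]) * v            # Euler step (dt is negative)
    return x                                       # latents; decode with the VAE
```

RK4 replaces the single velocity evaluation per step with four staged evaluations for higher integration accuracy at the same step count.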
|
|
|