---
license: apache-2.0
---
# Single-Stream DiT (Proof-of-Concept)

This repository contains the proof-of-concept code, curated dataset, and cached latents for training a custom Single-Stream Diffusion Transformer (DiT) from scratch.

The primary objective was to demonstrate the feasibility and training stability of coupling the high-fidelity **FLUX.1 VAE** and the **T5Gemma2** text encoder with a custom single-stream DiT.

**Note:** The entire project codebase can be found on the [GitHub](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2) page!
## Project Overview

### How it started

*(image)*

### Verification and Result Comparison

| Cached Latent Verification | Final Generated Sample (Euler 50 steps and CFG 3.0) |
| :---: | :---: |
| *(image)* | *(image)* |

## Core Models and Architecture

| Component | Model ID / Function | Purpose |
| :--- | :--- | :--- |
| **Generator** | `SingleStreamDiTV2` | Custom Single-Stream DiT featuring Visual Fusion blocks, Context Refiners, and Fourier Filters. DiT parameters: _768 hidden size, 12 heads, 16 depth, 2 refiner depth, 128 text token length, 2 patch size._ |
| **Text Encoder** | `google/t5gemma-2-1b-1b` | Generates rich, 1152-dimensional text embeddings for high-quality semantic guidance. |
| **VAE** | `diffusers/FLUX.1-vae` | A 16-channel VAE with an 8x downsample factor, providing superior reconstruction for complex textures. |
| **Training Method** | Flow Matching (V-Prediction) | Optimized with a velocity-based objective and an optional Self-Evaluation (Self-E) consistency loss. |

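The velocity-prediction objective named in the table can be sketched as follows. This is a minimal illustration, not the repository's exact training code: `flow_matching_loss`, the `model(x_t, t, text_emb)` signature, and the tensor shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x0, text_emb):
    """Flow matching with a velocity target along a linear noise path.

    x0: clean VAE latents, e.g. (B, 16, H/8, W/8) for a 16-channel VAE.
    """
    t = torch.rand(x0.size(0), device=x0.device)   # uniform timesteps in [0, 1)
    noise = torch.randn_like(x0)
    t_ = t.view(-1, 1, 1, 1)
    x_t = (1 - t_) * x0 + t_ * noise               # interpolate data -> noise
    v_target = noise - x0                          # constant velocity of the path
    v_pred = model(x_t, t, text_emb)
    return F.mse_loss(v_pred, v_target)
```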
## New in V3

- **Refinement Stages:** Separate noise and context refiner blocks to "prep" tokens before the joint fusion phase.
- **Fourier Filters:** Frequency-domain processing layers to improve global structural coherence.
- **Local Spatial Bias:** Conv2D-based depthwise biases to reinforce local texture within the transformer.
- **Rotary Embeddings (RoPE):** Dynamic 2D-RoPE grid support for area-preserving bucketing.

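A frequency-domain filter of the kind listed above can be sketched like this. It is an illustrative layer under assumed shapes, not the repository's exact `model.py` implementation: transform the spatial token grid to the frequency domain, scale each frequency with a learnable gate, and transform back.

```python
import torch
import torch.nn as nn

class FourierFilter(nn.Module):
    def __init__(self, height, width, dim):
        super().__init__()
        # one learnable gain per channel and frequency bin, initialized to
        # identity; rfft2 stores only width // 2 + 1 unique frequency columns
        self.gate = nn.Parameter(torch.ones(dim, height, width // 2 + 1))

    def forward(self, x):  # x: (B, dim, H, W) spatial token grid
        freq = torch.fft.rfft2(x, norm="ortho")            # complex spectrum
        freq = freq * self.gate                            # reweight frequencies
        return torch.fft.irfft2(freq, s=x.shape[-2:], norm="ortho")
```

Because the gate starts at 1.0, the layer is initialized as an identity map, which keeps early training stable while still letting the model learn global structural corrections.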
## Training Progression

| Early Epoch (Epoch 25) | Final Epoch (Epoch 1200) | Full Progression |
| :---: | :---: | :---: |
| *(image)* | *(image)* | *(image)* |

## Data Curation and Preprocessing

The model was tested on a curated dataset of **200 images** (10 categories of flowers) before scaling to larger datasets.

| Component | Tool / Method | Purpose / Detail |
| :--- | :--- | :--- |
| **Pre/Post-processing** | **[Dataset Helpers](https://github.com/Particle1904/DatasetHelpers)** | Used to resize images (using **[DPID](https://github.com/Mishini/dpid)** - Detail-Preserving Image Downscaling) and edit the Qwen3-VL captions. |
| **Captioning** | **Qwen3-VL-4B-Instruct** | Captions include precise botanical details: texture (waxy, serrated), plant anatomy (stamen, pistil), and camera lighting. |
| **Data Encoding** | `preprocess.py` | Encodes images via the FLUX VAE and text via T5Gemma2, applying aspect-ratio bucketing. |

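The aspect-ratio bucketing step can be sketched as follows. This is a toy illustration, not `preprocess.py`'s actual logic; the bucket list and `pick_bucket` helper are hypothetical, showing only the core idea of assigning each image to the roughly area-preserving resolution whose aspect ratio it matches best.

```python
# hypothetical example buckets with roughly equal pixel area (~512*512)
BUCKETS = [(512, 512), (448, 576), (576, 448), (384, 640), (640, 384)]

def pick_bucket(width, height):
    """Return the bucket (w, h) whose aspect ratio is closest to the image's."""
    ar = width / height
    return min(BUCKETS, key=lambda b: abs(b[0] / b[1] - ar))
```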
<details>

<summary><h2><b>Qwen3-VL-4B-Instruct System Instruction (Captioning Prompt)</b></h2></summary>

…

<i>Output Format: Output a single string containing the caption, without double quotes, using commas to separate phrases.</i>

</details>
## Training History and Configuration

Training uses **8-bit AdamW** with a **cosine schedule (5% warmup)**, trained with **MSE** for 1200 epochs (stopped early).

| Configuration | Value | Purpose |
| :--- | :--- | :--- |
| **Loss** | **`MSE` at `2e-4` learning rate** | Trained with MSE only. |
| **Batch Size** | **`16`** | Gradient checkpointing enabled; gradient accumulation steps set to 2. |
| **Shift Value** | **`1.0` (Uniform)** | Ensures balanced training across all noise levels, critical for learning geometry on small datasets. |
| **Latent Norm** | **`0.0 Mean / 1.0 Std`** | Hardcoded identity normalization to preserve the relative channel relationships of the FLUX VAE. **Note:** Using a mean and std calculated from the dataset resulted in poor reconstruction with artifacts. |
| **EMA Decay** | **`0.999`** | Maintains a moving average of weights for smoother, higher-quality inference. |
| **Self-Evolution** | **`Disabled`** | Optional teacher-student distillation. (**Note:** Not used in this PoC to maintain baseline architectural clarity.) |

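The EMA tracking in the table (decay 0.999) can be sketched as below. This is a generic sketch, not the repository's exact class; the `EMA` name and `update` method are illustrative.

```python
import copy
import torch

class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # frozen shadow copy of the model, updated after each optimizer step
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
```

At inference time, sampling from `ema.shadow` rather than the live model gives the smoother, higher-quality outputs the table refers to.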
### Loss & Fourier Gate Progression

| Loss Graph | Fourier Gate |
| :---: | :---: |
| *(image)* | *(image)* |

**Training Time Estimate:**

* **GPU Time:** Approximately **6 hours and 21 minutes** of total GPU compute for 1200 epochs (RTX 5060 Ti 16GB).
* **Project Time (Human):** 13 days of R&D, including hyperparameter tuning.

## Reproducibility

This repository is designed to be fully reproducible. The following data is included in the respective directories:

* **Raw Dataset:** The original `.png` images and the **Qwen3-VL-4B-Instruct** generated and reviewed `.txt` captions.
* **Cached Dataset:** The processed, tokenized, and VAE-encoded latents (`.pt` files).

## Repository File Breakdown

### Training & Core Scripts

| File | Purpose | Notes |
| :--- | :--- | :--- |
| **`train.py`** | Main training script. Supports EMA, Self-E, and gradient accumulation. | Includes automatic model compilation on Linux. |
| **`model.py`** | Defines `SingleStreamDiTV2` with Visual Fusion, Fourier Filters, and SwiGLU. | The core architecture definition. |
| **`config.py`** | Central configuration for paths, model dims, and hyperparameters. | All model settings are controlled here. |
| **`sanity_check.py`** | A utility to ensure the model can overfit a single cached latent file. | Used for debugging architecture changes. |

### Utility & Preprocessing

| File | Purpose | Notes |
| :--- | :--- | :--- |
| **`preprocess.py`** | Prepares raw image/text data into cached `.pt` files using the VAE and T5. | Run this before starting training. |
| **`calculate_cache_statistics.py`** | Analyzes cached latents to find mean/std for normalization settings. | **Note:** Use results with caution; defaults of 0.0/1.0 are often better. |
| **`debug_vae_pipeline.py`** | Tests the VAE reconstruction pipeline in float32 to isolate VAE issues. | Useful for troubleshooting color shifts. |
| **`check_cache.py`** | Decodes a single cached latent back to an image to verify preprocessing. | Fast integrity check. |
| **`generate_graph.py`** | Generates the loss curve visualization from the training CSV logs. | Creates `loss_curve.png`. |

### Inference & Data

| File | Purpose | Notes |
| :--- | :--- | :--- |
| **`inferenceNotebook.ipynb`** | Primary inference tool. Supports text-to-image with Euler/RK4. | Best for interactive testing. |
| **`samplers.py`** | Numerical integration steps for Euler and Runge-Kutta 4 (RK4). | Logic for flow-matching inference. |
| **`latents.py`** | Scaling and normalization logic for VAE latents. | Shared across preprocess, train, and inference. |
| **`dataset.py`** | Bucket-batching and RAM-caching dataset implementation. | Handles the training data pipeline. |