Crowlley committed on
Commit fbcff68 · verified · 1 Parent(s): 3a396c1

Update README.md

Files changed (1)
  1. README.md +60 -58
README.md CHANGED
@@ -2,51 +2,57 @@
  license: apache-2.0
  ---

- #

- # Single-Stream DiT (Proof-of-Concept)

- This repository contains the final Checkpoint for a Single-Stream Diffusion Transformer (DiT) Proof-of-Concept, heavily inspired by modern architectures like **Z-Image** and **Lumina Image 2**.

- The primary objective was to demonstrate the feasibility and training stability of coupling the high-fidelity **EQ-SDXL-VAE** with the powerful **T5Gemma2** text encoder for image generation on consumer-grade hardware (NVIDIA RTX 5060 Ti 16GB).

  **Note:** The entire project codebase can be found on the [GitHub](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2) page!

  ## Project Overview

  ### How it started

  ![How it started](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/bored.png?raw=true)

  ### Verification and Result Comparison

- | Cached Latent Verification | Final Generated Sample (RK4, Epoch 1200 EMA) |
  | :---: | :---: |
- | ![Cache Verification](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/cache_verification.png?raw=true) | ![Generated Sample](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/sample_rk4_cfg1.5.png?raw=true) |

  ## Core Models and Architecture

  | Component | Model ID / Function | Purpose |
  | :--- | :--- | :--- |
- | **Generator** | `SingleStreamDiTV2` | Custom Diffusion Transformer model that predicts the velocity vector ($v$) along the flow trajectory. |
- | **Text Encoder** | `google/t5gemma-2-1b-1b` | Used to generate rich, high-dimensional text embeddings (1152 dimensions) for conditional guidance (CFG). |
- | **VAE** | `KBlueLeaf/EQ-SDXL-VAE` | A high-quality SDXL-compatible VAE used for latent compression and reconstruction. Crucial for handling fine details. |
- | **Training Method** | Flow Matching (V-Prediction) | The model is trained to predict the vector field that transforms noise ($x_0$) to the clean latent ($x_1$). |

  ## Training Progression

- | Early Epoch (Epoch 10) | Final Epoch (Epoch 1200, RAW) | Full Progression (GIF) |
  | :---: | :---: | :---: |
- | ![Epoch10](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/epoch_10.png?raw=true) | ![Epoch1200](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/epoch_1200.png?raw=true) | ![Epochs over time](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/epochsOverTime.gif?raw=true) |

  ## Data Curation and Preprocessing

- The model was trained on a small, curated dataset of **200 images** (10 categories of flowers with 20 images each).

  | Component | Tool / Method | Purpose / Detail |
  | :--- | :--- | :--- |
- | **Pre/Post-processing of Dataset** | **[Dataset Tools](https://github.com/Particle1904/DatasetHelpers)** | Used to resize images to 512x512 and to edit the Qwen3-VL captions. |
- | **Captioning** | **Qwen3-VL-4B-Instruct** | The dataset was captioned using a specialized visual language model with a strict botanical system instruction. This ensured captions contained precise details on **texture (waxy, serrated), plant anatomy (stamen, pistil), lighting (shallow depth of field), and shot type (macro shot)**. |
- | **Data Encoding** | `preprocess.py` | Encodes images to latents via EQ-VAE and text via T5Gemma2, applying bucketing and horizontal flip augmentation (for Epochs 0-600). |

  <details>
  <summary><h2><b>Qwen3-VL-4B-Instruct System Instruction (Captioning Prompt)</b></h2></summary>
@@ -82,65 +88,61 @@ Camera Perspective and Style: Crucial for DiT training. Specify:
  Output Format: Output a single string containing the caption, without double quotes, using commas to separate phrases.</i>
  </details>

- ## Training History and Final Configuration

- Training utilized a **Cosine Annealing Learning Rate Scheduler** across all epochs to facilitate steady convergence, starting with $1e-4$ and ending at $1e-5$.

- | Epoch Range | Loss Function | Learning Rate (Start $\to$ End) | Key Features | Observation |
- | :--- | :--- | :--- | :--- | :--- |
- | **0 - 600** | Mean Squared Error (MSE) | $1e-4 \to \sim 5e-5$ | Horizontal Flip Augmentation **Enabled**. | Fast convergence on shape, but resulted in an undesirable "waxy" finish. At Epoch 600, the **Horizontal Flip Augmentation was removed** due to it causing artifacts (e.g., generating two flower stems). |
- | **601 - 900** | L1 Loss (MAE) | $5e-5 \to \sim 2e-5$ | Horizontal Flip **Disabled**. Switched to L1 Loss. | Modest improvement in sharpness and clarity. |
- | **901 - 1200** | L1 Loss (MAE) | $5e-5 \to 1e-5$ | Horizontal Flip **Disabled**. **Introduced EMA** + Latent Normalization to $\text{Std} \approx 1.0$. | **Optimal result.** Eliminated "waxy" look, successfully recovering and sharpening the high-frequency textural details. |

- ### Loss Progression

- ![Loss Graph](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/loss_curve.png?raw=true)

  **Training Time Estimate:**
- * **GPU Time:** Approximately **5 hours** of total GPU compute time for 1200 epochs (based on an average epoch time of $\sim 14$ seconds).
- * **Project Time (Human):** The overall development and hyperparameter tuning project took approximately 2-3 days.

  ## Reproducibility

  This repository is designed to be fully reproducible. The following data is included in the respective directories:
- * **Raw Dataset:** The original `.png` images and the **Qwen3-VL-4B-Instruct** generated `.txt` captions.
  * **Cached Dataset:** The processed, tokenized, and VAE-encoded latents (`.pt` files).
- * **Training Artifacts:** All checkpoint samples (from Epoch 10 to 1200) and all training log files (split by epoch range and a combined 0-1200 file).
-
- ### Key Configuration in `train_v3_ema.py`
-
- | Configuration | Value | Purpose |
- | :--- | :--- | :--- |
- | **Loss** | `F.l1_loss(pred, target_v)` | Uniform L1/MAE loss, which proved superior to MSE for generating fine details. |
- | **Latent Norm** | `x_1 = x_1 / 1.1908` | Normalizes latents to $\text{Std} \approx 1.0$ for optimal neural network stability. |
- | **EMA Decay** | `EMA_DECAY = 0.999` | Ensures a stable, high-quality checkpoint is saved, preventing weight oscillation. **The final result uses the EMA weights.** |

  ## Repository File Breakdown

- This section details the purpose and configurable parameters of each primary Python file.
-
- ### Training Scripts

- | File | Purpose | Key Configs | Notes |
- | :--- | :--- | :--- | :--- |
- | **`train_v3_ema.py`** | **Final, Optimal Training Script.** Uses L1 Loss, Latent Normalization, Cosine LR Annealing, and **EMA**. | `RESUME_FROM`, `START_EPOCH`, `BATCH_SIZE`, `LEARNING_RATE`, `EMA_DECAY` | This is the recommended script for any new training runs. |
- | **`train_v2_l1.py`** | Archive: Training script for epochs 601-900 (L1 Loss only). | `RESUME_FROM`, `START_EPOCH` | **DEPRECATED.** |
- | **`train.py`** | Archive: Initial training script for epochs 0-600 (MSE Loss, basic, **with flip augmentation**). | `RESUME_FROM`, `START_EPOCH` | **DEPRECATED.** |
- | **`train_overfit.py`** | A sanity check utility to ensure the model can overfit to a single data point. | `TARGET_FILE`, `STEPS` | For debugging architecture changes only. |

  ### Utility & Preprocessing

- | File | Purpose | Key Configs | Notes |
- | :--- | :--- | :--- | :--- |
- | **`preprocess.py`** | Prepares the raw image/text data into cached `.pt` files. Encodes images to latents via EQ-VAE and text via T5Gemma2. | `DATASET_DIR`, `OUTPUT_DIR`, `BUCKETS` | Must be run once before training. Includes image-flipping data augmentation (used only in Epochs 0-600 training). |
- | **`calculate_cache_statistics.py`** | Analyzes all cached `.pt` files to report the dataset's Mean and Standard Deviation. | `CACHE_DIR` | **CRITICAL** for determining the `LATENT_STD_SCALE` and `LATENT_OFFSET` used in training/inference. |
- | **`check_cache.py`** | Decodes a single cached latent file back into an image using the VAE to verify the preprocessing integrity. | `TARGET_FILE`, `VAE_ID` | Quick sanity check. |

- ### Inference Scripts

- | File | Purpose | Key Configs | Notes |
- | :--- | :--- | :--- | :--- |
- | **`inference_unified.py`** | **The main inference script.** Supports both **text-to-image** and **file-to-image** generation using either Euler or RK4 sampling. | `FILENAME` (for checkpoint), `INPUT_MODE`, `SAMPLER`, `GUIDANCE_SCALE`, `PROMPT`, `HEIGHT`, `WIDTH` | Recommended for all new generations. Integrates all necessary scaling factors. |
- | **`inference_euler.py`** | Archive: Old script for file-to-image inference with Euler sampling. | N/A | **DEPRECATED.** |
- | **`inference_rk4.py`** | Archive: Old script for file-to-image inference with RK4 sampling. | N/A | **DEPRECATED.** |
- | **`text_inference.py`** | Archive: Old script for text-to-image inference with Euler sampling. | N/A | **DEPRECATED.** |
 
  license: apache-2.0
  ---

+ # Single-Stream DiT with Global Fourier Filters (Proof-of-Concept)

+ This repository contains the codebase for a Single-Stream Diffusion Transformer (DiT) Proof-of-Concept, heavily inspired by modern architectures like **FLUX.1**, **Z-Image**, and **Lumina Image 2**.

+ The primary objective was to demonstrate the feasibility and training stability of coupling the high-fidelity **FLUX.1-VAE** with the powerful **T5Gemma2** text encoder for image generation on consumer-grade hardware (NVIDIA RTX 5060 Ti 16GB).

  **Note:** The entire project codebase can be found on the [GitHub](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2) page!

  ## Project Overview

  ### How it started
+
  ![How it started](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/bored.png?raw=true)

  ### Verification and Result Comparison

+ | Cached Latent Verification | Final Generated Sample (Euler, 50 steps, CFG 3.0) |
  | :---: | :---: |
+ | ![Cache Verification](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/cache_verification.png?raw=true) | ![Generated Sample](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/sample_euler_steps50_cfg3.png?raw=true) |

  ## Core Models and Architecture

  | Component | Model ID / Function | Purpose |
  | :--- | :--- | :--- |
+ | **Generator** | `SingleStreamDiTV2` | Custom Single-Stream DiT featuring Visual Fusion blocks, Context Refiners, and Fourier Filters. DiT parameters: _768 hidden size, 12 heads, 16 depth, 2 refiner depth, 128 text token length, 2 patch size._ |
+ | **Text Encoder** | `google/t5gemma-2-1b-1b` | Generates rich, 1152-dimensional text embeddings for high-quality semantic guidance. |
+ | **VAE** | `diffusers/FLUX.1-vae` | A 16-channel VAE with an 8x downsample factor, providing superior reconstruction for complex textures. |
+ | **Training Method** | Flow Matching (V-Prediction) | Optimized with a velocity-based objective and an optional Self-Evaluation (Self-E) consistency loss. |
+
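The flow-matching objective in the table above can be sketched in a few lines of PyTorch. This is an illustrative reconstruction, not the repository's exact training step; `model` and the tensor shapes are placeholders:

```python
import torch

def flow_matching_loss(model, x1, text_emb):
    """One v-prediction flow-matching step: the model predicts the
    velocity v = x1 - x0 along the straight path from noise to data."""
    x0 = torch.randn_like(x1)                      # noise sample (t = 0)
    t = torch.rand(x1.shape[0], device=x1.device)  # uniform timesteps (shift = 1.0)
    t_ = t.view(-1, 1, 1, 1)
    xt = (1.0 - t_) * x0 + t_ * x1                 # linear interpolation
    target_v = x1 - x0                             # constant velocity target
    pred_v = model(xt, t, text_emb)
    return torch.nn.functional.mse_loss(pred_v, target_v)
```

With a straight interpolation path the velocity target is constant in `t`, which is what makes the objective simple to supervise at every noise level.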
+ ## New in V3
+ - **Refinement Stages:** Separate noise and context refiner blocks to "prep" tokens before the joint fusion phase.
+ - **Fourier Filters:** Frequency-domain processing layers to improve global structural coherence.
+ - **Local Spatial Bias:** Conv2D-based depthwise biases to reinforce local texture within the transformer.
+ - **Rotary Embeddings (RoPE):** Dynamic 2D-RoPE grid support for area-preserving bucketing.

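A frequency-domain filter layer of this kind can be sketched as below (GFNet-style; the actual layer in `model.py` may differ in weight layout and gating, so treat this as an assumption-laden illustration):

```python
import torch
import torch.nn as nn

class FourierFilter(nn.Module):
    """Global frequency-domain filter over a 2D token grid.

    A learnable complex-valued mask modulates each spatial frequency, and a
    scalar gate (initialized to zero, i.e. identity) blends the filtered
    signal back into the residual stream."""

    def __init__(self, dim, h, w):
        super().__init__()
        # Complex weights stored as (real, imag); rfft2 halves the last spatial axis.
        self.weight = nn.Parameter(torch.randn(h, w // 2 + 1, dim, 2) * 0.02)
        self.gate = nn.Parameter(torch.zeros(1))  # starts as identity mapping

    def forward(self, x, h, w):
        b, n, d = x.shape                        # tokens as (B, H*W, D)
        x_sp = x.view(b, h, w, d)
        freq = torch.fft.rfft2(x_sp, dim=(1, 2), norm="ortho")
        freq = freq * torch.view_as_complex(self.weight)
        out = torch.fft.irfft2(freq, s=(h, w), dim=(1, 2), norm="ortho")
        return x + torch.tanh(self.gate) * out.reshape(b, n, d)
```

Because the gate starts at zero, the layer is a no-op at initialization and the network can learn how much global frequency mixing to apply, which is also what a "Fourier gate" curve would track over training.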
  ## Training Progression

+ | Early Epoch (Epoch 25) | Final Epoch (Epoch 1200) | Full Progression |
  | :---: | :---: | :---: |
+ | ![Epoch25](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/epoch_25.png?raw=true) | ![Epoch1200](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/epoch_1200.png?raw=true) | ![Epochs over time](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/training_progression.webp?raw=true) |

  ## Data Curation and Preprocessing

+ The model was tested on a curated dataset of **200 images** (10 categories of flowers) before scaling to larger datasets.

  | Component | Tool / Method | Purpose / Detail |
  | :--- | :--- | :--- |
+ | **Pre/Post-processing** | **[Dataset Helpers](https://github.com/Particle1904/DatasetHelpers)** | Used to resize images (using **[DPID](https://github.com/Mishini/dpid)** - Detail-Preserving Image Downscaling) and edit the Qwen3-VL captions. |
+ | **Captioning** | **Qwen3-VL-4B-Instruct** | Captions include precise botanical details: texture (waxy, serrated), plant anatomy (stamen, pistil), and camera lighting. |
+ | **Data Encoding** | `preprocess.py` | Encodes images via the FLUX VAE and text via T5Gemma2, applying aspect-ratio bucketing. |

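The caching step in `preprocess.py` amounts to encoding each image/caption pair once and saving the tensors. A hedged sketch, assuming a diffusers-style `AutoencoderKL` interface (`encode(x).latent_dist.sample()`) and an output schema chosen here for illustration only:

```python
import torch

@torch.no_grad()
def cache_example(vae, text_encoder, image, caption, out_path):
    """Encode one image/caption pair into a cached .pt file.

    `vae` exposes diffusers' AutoencoderKL-style encode();
    `text_encoder` returns an embedding tensor for the caption.
    The dict keys are illustrative, not preprocess.py's exact schema."""
    latent = vae.encode(image.unsqueeze(0)).latent_dist.sample()[0]
    emb = text_encoder(caption)
    torch.save({"latent": latent, "text_emb": emb}, out_path)
    return latent.shape
```

Caching both modalities once means neither the VAE nor T5Gemma2 needs to be resident in VRAM during training, which is what makes the 16 GB consumer-GPU budget workable.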
  <details>
  <summary><h2><b>Qwen3-VL-4B-Instruct System Instruction (Captioning Prompt)</b></h2></summary>

  Output Format: Output a single string containing the caption, without double quotes, using commas to separate phrases.</i>
  </details>

+ ## Training History and Configuration

+ Training uses **8-bit AdamW** and a **cosine schedule with 5% warmup**, with **MSE** loss, for 1200 epochs (stopped early).

+ | Configuration | Value | Purpose |
+ | :--- | :--- | :--- |
+ | **Loss** | **`MSE`** (LR `2e-4`) | Trained with MSE only. |
+ | **Batch Size** | **`16`** | Gradient checkpointing enabled; gradient accumulation steps set to 2. |
+ | **Shift Value** | **`1.0` (Uniform)** | Ensures balanced training across all noise levels, critical for learning geometry on small datasets. |
+ | **Latent Norm** | **`0.0 Mean / 1.0 Std`** | Hardcoded identity normalization to preserve the relative channel relationships of the FLUX VAE. **Note:** Using a Mean and Std calculated from the dataset resulted in poor reconstruction with artifacts. |
+ | **EMA Decay** | **`0.999`** | Maintains a moving average of weights for smoother, higher-quality inference. |
+ | **Self-Evaluation (Self-E)** | **`Disabled`** | Optional teacher-student distillation. (**Note:** Not used in this PoC to maintain baseline architectural clarity.) |

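The scheduler and EMA rows can be sketched in plain PyTorch (the repository uses 8-bit AdamW from `bitsandbytes`; only the schedule multiplier, suitable for `torch.optim.lr_scheduler.LambdaLR`, and the EMA update are shown here):

```python
import math
import torch

def warmup_cosine(step, total_steps, warmup_frac=0.05):
    """LR multiplier: linear warmup over the first 5% of steps,
    then cosine decay from the peak down to zero."""
    warmup = max(1, int(total_steps * warmup_frac))
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    """ema = decay * ema + (1 - decay) * online weights, applied per-parameter."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```

Calling `ema_update` after every optimizer step keeps a slow-moving copy of the weights; sampling from that copy is what smooths out the late-training weight oscillation.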
+ ### Loss & Fourier Gate Progression

+ | Loss Graph | Fourier Gate |
+ | :---: | :---: |
+ | ![Loss Graph](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/loss_curve.png?raw=true) | ![Fourier Gate](https://github.com/Particle1904/SingleStreamDiT_T5Gemma2/blob/main/readme_assets/fourier_gate.png?raw=true) |

  **Training Time Estimate:**
+ * **GPU Time:** Approximately **6 hours and 21 minutes** of total GPU compute time for 1200 epochs (RTX 5060 Ti 16GB).
+ * **Project Time (Human):** 13 days of R&D, including hyperparameter tuning.

  ## Reproducibility

  This repository is designed to be fully reproducible. The following data is included in the respective directories:
+ * **Raw Dataset:** The original `.png` images and the **Qwen3-VL-4B-Instruct** generated and reviewed `.txt` captions.
  * **Cached Dataset:** The processed, tokenized, and VAE-encoded latents (`.pt` files).

  ## Repository File Breakdown

+ ### Training & Core Scripts

+ | File | Purpose | Notes |
+ | :--- | :--- | :--- |
+ | **`train.py`** | Main training script. Supports EMA, Self-E, and Gradient Accumulation. | Includes automatic model compilation on Linux. |
+ | **`model.py`** | Defines `SingleStreamDiTV2` with Visual Fusion, Fourier Filters, and SwiGLU. | The core architecture definition. |
+ | **`config.py`** | Central configuration for paths, model dims, and hyperparameters. | All model settings are controlled here. |
+ | **`sanity_check.py`** | A utility to ensure the model can overfit to a single cached latent file. | Used for debugging architecture changes. |

  ### Utility & Preprocessing

+ | File | Purpose | Notes |
+ | :--- | :--- | :--- |
+ | **`preprocess.py`** | Prepares raw image/text data into cached `.pt` files using VAE and T5. | Run this before starting training. |
+ | **`calculate_cache_statistics.py`** | Analyzes cached latents to find Mean/Std for normalization settings. | **Note:** Use results with caution; defaults of 0.0/1.0 are often better. |
+ | **`debug_vae_pipeline.py`** | Tests the VAE reconstruction pipeline in float32 to isolate VAE issues. | Useful for troubleshooting color shifts. |
+ | **`check_cache.py`** | Decodes a single cached latent back to an image to verify preprocessing. | Fast integrity check. |
+ | **`generate_graph.py`** | Generates the loss curve visualization from the training CSV logs. | Creates `loss_curve.png`. |

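The statistic reported by `calculate_cache_statistics.py` is essentially a running mean/std over all cached latents. A minimal sketch (population std, accumulated in float64; function name and signature are illustrative):

```python
import torch

def cache_statistics(latents):
    """Global mean and (population) standard deviation over a collection of
    cached latent tensors, accumulated in float64 to avoid precision drift
    across many files."""
    total, total_sq, count = 0.0, 0.0, 0
    for z in latents:
        z = z.double()
        total += z.sum().item()
        total_sq += (z * z).sum().item()
        count += z.numel()
    mean = total / count
    std = (total_sq / count - mean ** 2) ** 0.5
    return mean, std
```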
+ ### Inference & Data

+ | File | Purpose | Notes |
+ | :--- | :--- | :--- |
+ | **`inferenceNotebook.ipynb`** | Primary inference tool. Supports text-to-image with Euler/RK4. | Best for interactive testing. |
+ | **`samplers.py`** | Numerical integration steps for Euler and Runge-Kutta 4 (RK4). | Logic for the flow matching inference. |
+ | **`latents.py`** | Scaling and normalization logic for VAE latents. | Shared across preprocess, train, and inference. |
+ | **`dataset.py`** | Bucket-batching and RAM-caching dataset implementation. | Handles the training data pipeline. |
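For reference, the Euler path integrates the learned velocity field from noise at t=0 to the latent at t=1. A minimal sketch of that loop (CFG and latent de-normalization omitted; `model` is a placeholder, not the exact `samplers.py` signature):

```python
import torch

@torch.no_grad()
def sample_euler(model, text_emb, shape, steps=50, device="cpu"):
    """Euler integration of the flow ODE: x <- x + v_theta(x, t) * dt,
    stepping from pure noise (t = 0) to the clean latent (t = 1)."""
    x = torch.randn(shape, device=device)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0],), i * dt, device=device)
        v = model(x, t, text_emb)  # predicted velocity at (x, t)
        x = x + v * dt
    return x
```

RK4 follows the same loop but evaluates the velocity four times per step for a higher-order update, trading roughly 4x the model calls for a more accurate trajectory at the same step count.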