# Minimal Code Package For My PixelDiT Three-Control Network

This folder contains the reusable code extracted from the current project. It is not just documentation, but a small Python package that implements the core innovations:

- independent `depth / seg / edge` control branches
- strict single-condition hard selection
- multi-condition layer-wise gated fusion
- DDP-safe mode sampling
- inactive branch gradient masking
- single-control and three-control dataset loading
- multi-condition cycle loss dispatch
- SoftCanny image-cycle edge consistency

## Files

```text
minimal_my_network/
  __init__.py
  independent_gated_control.py
  datasets.py
  losses.py
  README.md
```

## Core Model Code

Import the fusion module:

```python
from minimal_my_network import IndependentBranchGatedFusion
```

Create the fusion module:

```python
fusion = IndependentBranchGatedFusion(
    hidden_size=1536,
    num_layers=14,
    init_gate_logits=(0.5, 0.0, -0.5),
    control_structure_inject=(True, True, False),
    alpha_inject=2.0,
)
```

Inside the PixelDiT block loop, after you compute the branch tokens:

```python
x = fusion(
    hidden=x,
    layer_idx=inject_idx,
    branch_tokens=[depth_tokens, seg_tokens, edge_tokens],
    keep_mask=control_keep,  # [B, 3]
    branch_structure_maps=[depth_struct, seg_struct, edge_struct],
)
```

The behavior in each mode is:

```text
depth-only: uses depth branch only, gate ignored
seg-only: uses seg branch only, gate ignored
edge-only: uses edge branch only, gate ignored
multi-control: masked softmax gate over active branches only
```
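
The selection rule above can be sketched without any framework. This is a minimal illustration for one sample, not the package's implementation: the helper name `masked_softmax_gate` is hypothetical, and returning all-zero weights when no branch is active is an assumption.

```python
import math

def masked_softmax_gate(gate_logits, keep):
    """Return per-branch fusion weights for one sample (hypothetical helper).

    gate_logits: one logit per branch, in (depth, seg, edge) order.
    keep: 1 for an active branch, 0 for an inactive one.
    """
    active = [i for i, k in enumerate(keep) if k]
    if not active:
        # Assumption: with no active control, nothing is injected.
        return [0.0] * len(keep)
    if len(active) == 1:
        # Hard selection: the single active branch gets weight 1, gate ignored.
        return [1.0 if i == active[0] else 0.0 for i in range(len(keep))]
    # Multi-control: softmax restricted to the active branches.
    peak = max(gate_logits[i] for i in active)  # subtract the max for stability
    exps = [math.exp(g - peak) if k else 0.0 for g, k in zip(gate_logits, keep)]
    total = sum(exps)
    return [e / total for e in exps]

# depth + seg active, edge inactive, using the initial logits shown above
weights = masked_softmax_gate((0.5, 0.0, -0.5), keep=(1, 1, 0))
single = masked_softmax_gate((0.5, 0.0, -0.5), keep=(0, 1, 0))
```

Inactive branches receive exactly zero weight rather than a small one, so a multi-control sample never leaks signal from a branch that was dropped.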

## Training Utilities

```python
from minimal_my_network import (
    apply_multi_control_mode,
    sample_control_mode_ddp,
    mask_inactive_control_grads,
)
```

Sample one mode per step:

```python
mode = sample_control_mode_ddp(
    modes=("depth", "seg", "edge", "depth_seg", "depth_edge", "seg_edge", "depth_seg_edge"),
    probs=(0.15, 0.15, 0.15, 0.12, 0.12, 0.12, 0.19),
    enable_dropout=True,
    device=device,
)
```
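
Under DDP every rank should train the same mode in a given step. Otherwise ranks may touch different branch parameters before gradients are synchronized. The sketch below shows one way to get that agreement, by seeding a generator from the step number. It is an alternative illustration, not the package's method: the `device` argument above suggests the real function works with tensors, and `sample_mode_shared_seed`, `base_seed`, and `step` are hypothetical names.

```python
import random

MODES = ("depth", "seg", "edge", "depth_seg", "depth_edge", "seg_edge", "depth_seg_edge")
PROBS = (0.15, 0.15, 0.15, 0.12, 0.12, 0.12, 0.19)

def sample_mode_shared_seed(step, base_seed=0, modes=MODES, probs=PROBS):
    """Pick one mode that depends only on (base_seed, step), never on the rank."""
    rng = random.Random(base_seed * 1_000_003 + step)
    return rng.choices(modes, weights=probs, k=1)[0]

# Simulate four ranks at the same training step: all of them pick the same mode.
per_rank = [sample_mode_shared_seed(step=42) for _rank in range(4)]
```

Because the choice is a pure function of the step, no communication is needed. The trade-off is that every rank must agree on `base_seed` and on the step counter.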

Apply the sampled mode:

```python
control, control_keep = apply_multi_control_mode(control, mode, num_controls=3)
```
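
As a rough sketch of how a mode name could map to one row of the `[B, 3]` keep mask, assuming mode names are control names joined by underscores and the branch order is depth, seg, edge as in the fusion call. `mode_to_keep` is a hypothetical helper, not part of the package:

```python
CONTROL_ORDER = ("depth", "seg", "edge")

def mode_to_keep(mode, order=CONTROL_ORDER):
    """Turn a mode name such as "depth_edge" into a 0/1 keep row."""
    active = set(mode.split("_"))
    unknown = active - set(order)
    if unknown:
        raise ValueError(f"unknown control(s) in mode {mode!r}: {sorted(unknown)}")
    return [1 if name in active else 0 for name in order]

keep_row = mode_to_keep("depth_edge")
```

The same row would be repeated for every sample in the batch, since one mode is sampled per step.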

After the backward pass:

```python
mask_inactive_control_grads(model, mode)
```
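
The idea behind gradient masking can be illustrated with a duck-typed sketch that works on any `(name, param)` pairs, such as those yielded by a PyTorch module's `named_parameters()`. The `"{name}_branch."` parameter prefix and the function name are assumptions made for illustration. Whether the package sets gradients to `None` or zeroes them is not stated here.

```python
from types import SimpleNamespace

CONTROL_ORDER = ("depth", "seg", "edge")

def clear_inactive_branch_grads(named_params, mode, prefix_template="{name}_branch."):
    """Drop gradients of parameters that belong to branches not in `mode`.

    Assumes (hypothetically) that branch parameters are named "<control>_branch.*".
    """
    active = set(mode.split("_"))
    inactive_prefixes = tuple(
        prefix_template.format(name=n) for n in CONTROL_ORDER if n not in active
    )
    cleared = []
    for name, param in named_params:
        if name.startswith(inactive_prefixes):
            param.grad = None  # PyTorch optimizers skip parameters with no grad
            cleared.append(name)
    return cleared

# Stand-in parameters so the sketch runs without a model.
params = {
    "depth_branch.proj.weight": SimpleNamespace(grad=0.3),
    "seg_branch.proj.weight": SimpleNamespace(grad=0.1),
    "edge_branch.proj.weight": SimpleNamespace(grad=0.2),
    "backbone.attn.weight": SimpleNamespace(grad=0.5),
}
cleared = clear_inactive_branch_grads(params.items(), mode="depth")
```

Shared parameters such as the backbone are left untouched. Only the branches that did not take part in the step lose their gradients.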

## Dataset Code

Three-control dataset:

```python
from minimal_my_network.datasets import PixelThreeControlDataset, subdir_range

ds = PixelThreeControlDataset(
    image_root="data/blip/extracted",
    depth_root="data/blip_depth_da3_nested_giant_large_1_1",
    seg_root="data/blip_sam2_large_extracted",
    edge_root="data/blip_edge",
    subdirs=subdir_range(0, 199),
)
```

Single-control dataset:

```python
from minimal_my_network.datasets import PixelSingleControlDataset, subdir_range

seg_ds = PixelSingleControlDataset(
    image_root="data/blip/extracted",
    control_root="data/blip_sam2_large_extracted",
    control_type="seg",
    subdirs=subdir_range(0, 199),
)
```

## Loss Code

```python
from minimal_my_network import MultiConditionCycleLoss, SoftCannyImagePyramidCycleLoss

edge_loss = SoftCannyImagePyramidCycleLoss(
    gaussian_kernel=11,
    threshold_min=0.2745,
    threshold_max=0.5882,
    temperature=0.03,
)

cycle = MultiConditionCycleLoss(
    depth_cycle_loss=depth_loss,
    seg_cycle_loss=seg_loss,
    edge_cycle_loss=edge_loss,
    depth_weight=1.0,
    seg_weight=1.0,
    edge_weight=1.0,
)
```

Call the combined loss:

```python
loss = cycle(
    gen_image_m11,
    depth_01=depth,
    seg_01=seg,
    gt_image_m11=gt_image_m11,
    control_mode=mode,
)
```
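
The dispatch by `control_mode` can be pictured as a weighted sum over only the active conditions. The sketch below is an illustration under assumptions, not the class's actual code: `dispatch_cycle_loss` is a hypothetical name, and the per-condition losses are modeled as zero-argument callables returning floats.

```python
def dispatch_cycle_loss(mode, loss_fns, weights):
    """Sum weighted cycle losses for the conditions named in `mode`.

    loss_fns / weights: dicts keyed by "depth", "seg", "edge".
    Losses of inactive conditions are never evaluated.
    """
    total = 0.0
    for name in mode.split("_"):
        total += weights[name] * loss_fns[name]()
    return total

calls = []

def make_loss(name, value):
    # Record which losses were actually evaluated.
    def fn():
        calls.append(name)
        return value
    return fn

loss_fns = {
    "depth": make_loss("depth", 0.2),
    "seg": make_loss("seg", 0.3),
    "edge": make_loss("edge", 0.4),
}
weights = {"depth": 1.0, "seg": 1.0, "edge": 0.5}

total = dispatch_cycle_loss("depth_edge", loss_fns, weights)
```

Skipping inactive conditions also saves the cost of running their estimators for samples where the control was not provided.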

## What Is Not Included

This folder intentionally omits the full PixelDiT backbone. Keep using the original backbone from:

```text
pixdit_core/pixeldit.py
pixdit_core/pixeldit_t2i_control.py
```

This minimal package contains only the transferable code, which Codex can reuse in another implementation.