# Differences Between Reference Code and Current Implementation
## Critical Differences Affecting Results
### 1. **First Iteration Handling** ⚠️ **CRITICAL**
**Reference Code:**
```python
if itr == 0:
    # Don't add priors or diffusion noise to the first iteration
    output = model(image_tensor)
    # ... just get predictions, no gradient update
else:
    # Calculate loss and gradients
    if loss_infer == 'PGDD':
        loss = torch.nn.functional.mse_loss(features, noisy_features)
        grad = torch.autograd.grad(loss, image_tensor)[0]
        adjusted_grad = inferstep.step(image_tensor, grad)
        # ... apply gradient and noise
```
**Current Implementation:**
- **MISSING**: No check for `itr == 0` or `i == 0`
- Applies gradients and diffusion noise from the very first iteration
- As a result, the optimization trajectory diverges from the reference from the very first step (a minimal fix is sketched below)
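A minimal sketch of the missing guard, assuming the loop variable is `i`, the working tensor is `x` (see section 5 below), and `compute_inference_loss` is a hypothetical stand-in for whichever loss is configured:
```python
import torch

for i in range(num_iterations):
    if i == 0:
        # First iteration: forward pass only (no priors, no gradient
        # update, no diffusion noise), matching the reference.
        with torch.no_grad():
            output = model(x)
        continue
    output = model(x)
    loss = compute_inference_loss(output)  # hypothetical placeholder
    grad = torch.autograd.grad(loss, x)[0]
    adjusted_grad = inferstep.step(x, grad)
    diffusion_noise = diffusion_noise_ratio * torch.randn_like(x)
    x = inferstep.project(x.clone() + adjusted_grad + diffusion_noise)
    x = x.detach().requires_grad_(True)  # re-enable gradients for the next step
```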
### 2. **Model Extraction for PGDD**
**Reference Code:**
```python
new_model = extract_middle_layers(model.module, top_layer)
```
**Current Implementation:**
- Complex logic to handle Sequential models with normalizers
- Extracts from `model[1]` if Sequential, otherwise from `model`
- May handle `DataParallel` differently (see the sketch below)
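A hedged sketch of extraction logic that matches the reference while still tolerating the wrappers the current code worries about; the `Sequential(normalizer, backbone)` layout is an assumption about the current code, and `extract_middle_layers`/`top_layer` are the names used above:
```python
import torch.nn as nn

def get_backbone(model):
    # Unwrap DataParallel first, matching the reference's `model.module`.
    if isinstance(model, nn.DataParallel):
        model = model.module
    # Assumed layout: Sequential(normalizer, backbone).
    if isinstance(model, nn.Sequential) and len(model) == 2:
        model = model[1]
    return model

new_model = extract_middle_layers(get_backbone(model), top_layer)
```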
### 3. **Gradient Calculation**
**Reference Code:**
```python
grad = torch.autograd.grad(loss, image_tensor)[0] # No retain_graph for PGDD
```
**Current Implementation:**
- Same for PGDD (no retain_graph)
- But uses `retain_graph=True` for IncreaseConfidence (illustrated below)
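For context, `retain_graph=True` is only needed when the same forward graph is differentiated more than once, which is presumably why IncreaseConfidence uses it. A minimal illustration; `labels` and the two losses are hypothetical:
```python
import torch
import torch.nn.functional as F

output = model(x)  # one forward pass, one autograd graph
loss_a = F.cross_entropy(output, labels)
grad_a = torch.autograd.grad(loss_a, x, retain_graph=True)[0]  # keep graph alive
loss_b = output.logsumexp(dim=1).mean()
grad_b = torch.autograd.grad(loss_b, x)[0]  # graph is freed after this call
```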
### 4. **Normalization Handling**
**Reference Code:**
- Normalization is applied in the transform at the beginning
- `inference_normalization` controls whether transform includes normalization
- Model forward pass uses the already-normalized tensor
**Current Implementation:**
- Complex logic checking if model is Sequential with NormalizeByChannelMeanStd
- May apply normalization multiple times or inconsistently
- Takes different code paths for Sequential vs. non-Sequential models (a single-application sketch follows)
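A sketch of the reference convention, where normalization lives in the input transform and the loop never re-normalizes. The ImageNet mean/std values and `pil_image` are assumptions for illustration; `inference_normalization` is the flag named above:
```python
from torchvision import transforms

steps = [transforms.ToTensor()]
if inference_normalization:
    # Normalize exactly once, at load time (assumed ImageNet statistics).
    steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225]))
transform = transforms.Compose(steps)

image_tensor = transform(pil_image).unsqueeze(0).cuda()
# Inside the loop the model consumes the already-normalized tensor:
# features = new_model(image_tensor)
```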
### 5. **Variable Naming and Structure**
**Reference Code:**
- Uses `image_tensor` throughout the loop
- Directly modifies `image_tensor` with `requires_grad=True`
**Current Implementation:**
- Creates separate `x = image_tensor.clone().detach().requires_grad_(True)`
- Uses `x` in the loop instead of `image_tensor`
### 6. **Loss Function for IncreaseConfidence**
**Reference Code:**
```python
loss = calculate_loss(features, least_confident_classes[0], loss_function)
# Uses CrossEntropyLoss or MSELoss based on loss_function
```
**Current Implementation:**
```python
# Creates one-hot targets and uses MSE on softmax outputs
loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
```
- The loss calculation differs: the current code takes the MSE between softmax probabilities and a one-hot target, whereas the reference applies CrossEntropy (or MSE, depending on `loss_function`) to the raw logits; an aligned sketch follows
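To align the two, the loss could branch on `loss_function` and apply cross-entropy to the raw logits, as the reference's `calculate_loss` does. A sketch; the exact branching and the MSE variant are assumptions reconstructed from the snippet above, not the repository's confirmed helper:
```python
import torch
import torch.nn.functional as F

def calculate_loss(features, target_class, loss_function):
    # Assumed reconstruction of the reference helper: cross-entropy on raw
    # logits, or MSE against a one-hot target, selected by `loss_function`.
    target = torch.full((features.shape[0],), target_class,
                        dtype=torch.long, device=features.device)
    if loss_function == 'CrossEntropyLoss':
        return F.cross_entropy(features, target)
    one_hot = F.one_hot(target, num_classes=features.shape[1]).float()
    return F.mse_loss(features, one_hot)

loss = calculate_loss(output, least_confident_classes[0], loss_function)
```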
### 7. **Diffusion Noise Application**
**Reference Code:**
```python
if itr == 0:
    pass  # Skip noise
else:
    diffusion_noise = diffusion_noise_ratio * torch.randn_like(image_tensor).cuda()
    if loss_infer == 'GradModulation':
        image_tensor = inferstep.project(
            image_tensor.clone()
            + adjusted_grad * grad_modulation
            + diffusion_noise * grad_modulation
        )
    else:
        image_tensor = inferstep.project(
            image_tensor.clone() + adjusted_grad + diffusion_noise
        )
```
**Current Implementation:**
- Always applies diffusion noise: there is no `itr == 0` check, so noise is injected from the very first iteration (a minimal gate is sketched below)
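A minimal gate matching the reference, assuming the current loop index is `i` and the working tensor is `x`:
```python
# Skip diffusion noise on the first iteration, as the reference does.
if i == 0:
    diffusion_noise = torch.zeros_like(x)
else:
    diffusion_noise = diffusion_noise_ratio * torch.randn_like(x)
```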
### 8. **Model Forward Pass in Loop**
**Reference Code:**
```python
if inference_config['misc_info'].get('smooth_inference', False):
    ...  # Smooth inference logic
else:
    new_model.zero_grad()
    features = new_model(image_tensor)
```
**Current Implementation:**
```python
x.grad = None  # Instead of new_model.zero_grad()
if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
    output = layer_model(x)
else:
    output = model(x)
```
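The two calls clear different state: `new_model.zero_grad()` resets the model's parameter gradients, while `x.grad = None` resets the input tensor's accumulated gradient. Since `torch.autograd.grad(loss, x)` returns gradients directly without populating `.grad` attributes, neither call is strictly required, but clearing both makes the two code paths behave identically; a defensive sketch:
```python
# Clear both parameter and input gradients before the forward pass.
# torch.autograd.grad() does not accumulate into .grad, so this is
# defensive rather than strictly necessary.
model.zero_grad()
x.grad = None
output = model(x)
```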
## Summary of Impact
1. **First iteration difference**: most critical; the reference skips the gradient update and diffusion noise on iteration 0, while the current code applies both immediately
2. **Normalization**: applied at different points (and possibly more than once), which can cause numerical differences
3. **Loss calculation**: IncreaseConfidence uses a different loss formulation than the reference
4. **Model extraction**: Sequential/DataParallel handling may select different layers than the reference
## Recommended Fixes
1. Add an `if i == 0:` check that skips the gradient update and diffusion noise on the first iteration
2. Simplify model extraction to match the reference: `extract_middle_layers(model.module, top_layer)`
3. Align the IncreaseConfidence loss calculation with the reference
4. Ensure normalization is applied exactly once, consistently with the reference transform (a consolidated sketch follows)
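Putting the four fixes together, a hedged sketch of the corrected loop. Names such as `image_tensor`, `loss_infer`, `noisy_features`, `inferstep`, and `diffusion_noise_ratio` come from the snippets above; `num_iterations` and `calculate_loss` are assumptions standing in for the repository's actual loop bound and loss helper:
```python
import torch
import torch.nn.functional as F

# Fix 2: extraction matches the reference.
new_model = extract_middle_layers(model.module, top_layer)

# Fix 4: `image_tensor` is assumed to be normalized exactly once, by the
# input transform; the loop below never re-normalizes it.
x = image_tensor.clone().detach().requires_grad_(True)

for i in range(num_iterations):
    if i == 0:
        # Fix 1: forward pass only on the first iteration.
        with torch.no_grad():
            features = new_model(x)
        continue

    x.grad = None
    features = new_model(x)  # PGDD path; IncreaseConfidence would use the full model
    if loss_infer == 'PGDD':
        loss = F.mse_loss(features, noisy_features)
    else:
        # Fix 3: reference-style loss for IncreaseConfidence.
        loss = calculate_loss(features, least_confident_classes[0], loss_function)

    grad = torch.autograd.grad(loss, x)[0]
    adjusted_grad = inferstep.step(x, grad)
    diffusion_noise = diffusion_noise_ratio * torch.randn_like(x)
    x = inferstep.project(x.clone() + adjusted_grad + diffusion_noise)
    x = x.detach().requires_grad_(True)
```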