Tahereh, commit 420f791: Update to Generative Inference for Psychiatry Demo: add Noise stimulus, update parameters, fix model loading, and improve UI
Differences Between Reference Code and Current Implementation
Critical Differences Affecting Results
1. First Iteration Handling ⚠️ CRITICAL
Reference Code:
```python
if itr == 0:
    # Don't add priors or diffusion noise on the first iteration;
    # just get predictions, no gradient update
    output = model(image_tensor)
else:
    # Calculate loss and gradients
    if loss_infer == 'PGDD':
        loss = torch.nn.functional.mse_loss(features, noisy_features)
        grad = torch.autograd.grad(loss, image_tensor)[0]
        adjusted_grad = inferstep.step(image_tensor, grad)
    # ... apply gradient and noise
```
Current Implementation:
- MISSING: no `itr == 0` (or `i == 0`) check; gradients and diffusion noise are applied from the very first iteration
- This causes different starting behavior
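The gating described above can be sketched as a minimal loop. This is a hedged illustration, not the demo's actual code: `model`, `step_size`, and the stand-in loss are placeholders, and projection, priors, and GPU transfers are omitted.

```python
import torch

# Minimal sketch of the reference's first-iteration gating.
# All names here are placeholders, not the demo's actual API.
def run_inference(model, image_tensor, n_iters=3, step_size=0.1):
    for itr in range(n_iters):
        image_tensor = image_tensor.clone().detach().requires_grad_(True)
        output = model(image_tensor)
        if itr == 0:
            continue  # iteration 0: predictions only, no gradient update or noise
        loss = output.pow(2).mean()  # stand-in for the PGDD feature-matching loss
        grad = torch.autograd.grad(loss, image_tensor)[0]
        image_tensor = (image_tensor - step_size * grad).detach()
    return image_tensor

model = torch.nn.Linear(4, 4)
x0 = torch.randn(1, 4)
x1 = run_inference(model, x0)
print(x1.shape)  # torch.Size([1, 4])
```

With this structure the image only starts drifting on iteration 1, matching the reference's starting behavior.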
2. Model Extraction for PGDD
Reference Code:
```python
new_model = extract_middle_layers(model.module, top_layer)
```
Current Implementation:
- Complex logic to handle Sequential models with normalizers
- Extracts from `model[1]` if the model is Sequential, otherwise from `model`
- May handle DataParallel differently
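One plausible shape for `extract_middle_layers` is to keep named children up to and including `top_layer`; this is a guess at the helper's behavior for illustration only, and the reference's real implementation may differ.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of extract_middle_layers: truncate a model at a
# named child layer. The reference's actual helper may differ.
def extract_middle_layers(model, top_layer):
    kept = []
    for name, child in model.named_children():
        kept.append(child)
        if name == top_layer:
            break
    return nn.Sequential(*kept)

backbone = nn.Sequential()
backbone.add_module('conv', nn.Linear(8, 6))
backbone.add_module('relu', nn.ReLU())
backbone.add_module('fc', nn.Linear(6, 2))
trunk = extract_middle_layers(backbone, 'relu')
print(len(trunk))  # 2: the 'conv' and 'relu' children
```

Under this reading, wrapping the whole model in an extra `Sequential(normalizer, model)` shifts the child names, which is exactly why the current implementation's `model[1]` special-casing can end up truncating at a different layer.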
3. Gradient Calculation
Reference Code:
```python
grad = torch.autograd.grad(loss, image_tensor)[0]  # No retain_graph for PGDD
```
Current Implementation:
- Same for PGDD (no `retain_graph`)
- But uses `retain_graph=True` for IncreaseConfidence
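The `retain_graph` distinction matters only when the same computation graph is differentiated more than once. A small standalone check (unrelated to the demo's tensors):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# retain_graph=True keeps the graph alive so a second backward pass
# through the SAME graph (as in the IncreaseConfidence path) does not fail.
g1 = torch.autograd.grad(y, x, retain_graph=True)[0]
g2 = torch.autograd.grad(y, x)[0]  # would raise without retain_graph above
print(torch.equal(g1, g2))  # True
```

So the extra `retain_graph=True` in the current implementation is harmless for a single grad call, but it signals that IncreaseConfidence differentiates the graph twice, which PGDD does not.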
4. Normalization Handling
Reference Code:
- Normalization is applied in the transform at the beginning
- `inference_normalization` controls whether the transform includes normalization
- The model forward pass uses the already-normalized tensor
Current Implementation:
- Complex logic checking whether the model is a Sequential wrapping NormalizeByChannelMeanStd
- May apply normalization multiple times or inconsistently
- Different code paths for Sequential vs non-Sequential models
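The "normalize inside the model" pattern the reference relies on can be sketched as follows; `NormalizeByMeanStd` here is a stand-in for the demo's `NormalizeByChannelMeanStd`, whose exact implementation may differ.

```python
import torch
import torch.nn as nn

# Stand-in for NormalizeByChannelMeanStd: per-channel (x - mean) / std
# baked into the model so callers always pass raw [0, 1] tensors.
class NormalizeByMeanStd(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', torch.tensor(mean).view(1, -1, 1, 1))
        self.register_buffer('std', torch.tensor(std).view(1, -1, 1, 1))

    def forward(self, x):
        return (x - self.mean) / self.std

net = nn.Sequential(NormalizeByMeanStd([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                    nn.Conv2d(3, 4, 3))
out = net(torch.rand(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 4, 6, 6])
```

The double-normalization bug arises when the input transform also normalizes: the tensor then passes through `(x - mean) / std` twice, once outside the model and once inside.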
5. Variable Naming and Structure
Reference Code:
- Uses `image_tensor` throughout the loop
- Directly modifies `image_tensor` with `requires_grad=True`
Current Implementation:
- Creates a separate `x = image_tensor.clone().detach().requires_grad_(True)`
- Uses `x` in the loop instead of `image_tensor`
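The practical consequence of the `x = image_tensor.clone().detach().requires_grad_(True)` pattern is that gradients land on the copy, not the original. A minimal check:

```python
import torch

image_tensor = torch.randn(2, 3)

# The current implementation's pattern: a detached leaf copy collects
# the gradients, leaving the original tensor untouched.
x = image_tensor.clone().detach().requires_grad_(True)
loss = (x ** 2).sum()
loss.backward()

print(x.grad is not None)         # True: x is the differentiated leaf
print(image_tensor.grad is None)  # True: the original gets no gradient
```

Functionally the two variants can be equivalent, but only if every subsequent read and write in the loop consistently uses `x`; mixing `x` and `image_tensor` silently drops updates.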
6. Loss Function for IncreaseConfidence
Reference Code:
```python
loss = calculate_loss(features, least_confident_classes[0], loss_function)
# Uses CrossEntropyLoss or MSELoss depending on loss_function
```
Current Implementation:
```python
# Creates one-hot targets and uses MSE on softmax outputs
loss = loss + F.mse_loss(F.softmax(output, dim=1), one_hot)
```
- Different loss calculation method
- Uses MSE on softmax probabilities vs CrossEntropy on logits
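That the two losses are not interchangeable is easy to verify: on the same logits they produce different gradients, so the inference trajectory diverges. A self-contained check (made-up logits, not the demo's data):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]], requires_grad=True)
target = torch.tensor([1])
one_hot = F.one_hot(target, num_classes=3).float()

# Reference-style: cross-entropy on raw logits
ce = F.cross_entropy(logits, target)
g_ce = torch.autograd.grad(ce, logits)[0]

# Current-style: MSE between softmax probabilities and a one-hot target
mse = F.mse_loss(F.softmax(logits, dim=1), one_hot)
g_mse = torch.autograd.grad(mse, logits)[0]

print(torch.allclose(g_ce, g_mse))  # False: the two losses steer differently
```

Cross-entropy on logits yields the well-known `softmax(logits) - one_hot` gradient; MSE on softmax adds an extra softmax-Jacobian factor, which both rescales and vanishes near saturated probabilities.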
7. Diffusion Noise Application
Reference Code:
```python
if itr == 0:
    pass  # Skip noise on the first iteration
else:
    diffusion_noise = diffusion_noise_ratio * torch.randn_like(image_tensor).cuda()
    if loss_infer == 'GradModulation':
        image_tensor = inferstep.project(
            image_tensor.clone()
            + adjusted_grad * grad_modulation
            + diffusion_noise * grad_modulation
        )
    else:
        image_tensor = inferstep.project(
            image_tensor.clone() + adjusted_grad + diffusion_noise
        )
```
Current Implementation:
- Always applies diffusion noise (no `itr == 0` check)
- Noise is applied in all iterations, including the first
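The gated-noise update can be paraphrased in a CPU-friendly form (hedged: `project()` is taken as the identity here, and the `.cuda()` transfer and modulation branch are dropped for brevity):

```python
import torch

# CPU-friendly paraphrase of the reference's noise gating: no noise or
# gradient step on iteration 0, scaled Gaussian noise afterwards.
def noisy_step(image_tensor, adjusted_grad, itr, diffusion_noise_ratio=0.05):
    if itr == 0:
        return image_tensor
    noise = diffusion_noise_ratio * torch.randn_like(image_tensor)
    return image_tensor + adjusted_grad + noise

x = torch.zeros(2, 2)
g = torch.ones(2, 2) * 0.1
print(torch.equal(noisy_step(x, g, itr=0), x))  # True: first step untouched
```

Dropping the `itr == 0` guard, as the current implementation does, means every trajectory starts one noisy step away from the true input, so its early predictions are not comparable to the reference's.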
8. Model Forward Pass in Loop
Reference Code:
```python
if inference_config['misc_info'].get('smooth_inference', False):
    pass  # Smooth inference logic
else:
    new_model.zero_grad()
    features = new_model(image_tensor)
```
Current Implementation:
```python
x.grad = None  # Instead of new_model.zero_grad()
if config['loss_infer'] == 'Prior-Guided Drift Diffusion' and layer_model is not None:
    output = layer_model(x)
else:
    output = model(x)
```
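The two zeroing strategies clear different gradient buffers, which a small standalone example makes concrete:

```python
import torch

x = torch.randn(3, requires_grad=True)
x.sum().backward()
print(x.grad is not None)  # True: the input's gradient is populated

# The current implementation clears the INPUT's gradient buffer...
x.grad = None

# ...whereas the reference clears the MODEL's parameter gradients.
model = torch.nn.Linear(3, 1)
model(x).sum().backward()
model.zero_grad()
# All parameter grads are now cleared (None or zeroed, depending on
# the PyTorch version's zero_grad default).
print(all(p.grad is None or bool(torch.all(p.grad == 0))
          for p in model.parameters()))
```

Since inference optimizes the input rather than the weights, `x.grad = None` is arguably the more targeted choice; but if stale parameter gradients are ever read elsewhere, the two variants are not equivalent.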
Summary of Impact
- First-iteration difference: most critical; the reference skips the gradient update and noise on iteration 0
- Normalization: different application order may cause numerical differences
- Loss calculation: different methods for IncreaseConfidence
- Model extraction: may extract different layers due to Sequential handling
Recommended Fixes
- Add an `if i == 0:` check to skip the gradient update on the first iteration
- Simplify model extraction to match the reference: `extract_middle_layers(model.module, top_layer)`
- Align the IncreaseConfidence loss calculation with the reference
- Ensure normalization is applied consistently
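The fixes above can be consolidated into a sketch of the corrected PGDD loop. This is hedged throughout: `inferstep.project`, GPU transfers, and the real step-size logic from `inferstep.step` are omitted, and every name is a placeholder rather than the demo's actual API.

```python
import torch

# Hedged sketch of the corrected PGDD loop, folding in the recommended
# fixes: first-iteration skip, single leaf tensor, gated noise.
def generative_inference(layer_model, image_tensor, noisy_features,
                         n_iters=4, diffusion_noise_ratio=0.05):
    for i in range(n_iters):
        x = image_tensor.clone().detach().requires_grad_(True)
        features = layer_model(x)
        if i == 0:
            continue  # fix: no gradient update or noise on the first iteration
        loss = torch.nn.functional.mse_loss(features, noisy_features)
        grad = torch.autograd.grad(loss, x)[0]
        noise = diffusion_noise_ratio * torch.randn_like(x)
        image_tensor = (x - grad + noise).detach()  # stand-in for inferstep logic
    return image_tensor

torch.manual_seed(0)
layer = torch.nn.Linear(4, 4)
target = torch.randn(1, 4)
out = generative_inference(layer, torch.randn(1, 4), target)
print(out.shape)  # torch.Size([1, 4])
```

Any port of this sketch back into the demo should restore the projection step and the `inferstep.step` gradient adjustment, which are what keep the iterates inside the valid image range.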