"""
Copyright (c) 2024 TU Munich
Author: Nikita Araslanov
License: Apache License 2.0
"""

import importlib
import math
import os
import random
import sys
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tf

from util.ema_pytorch import EMA


#######################################
#       RAFT & SEARAFT & SMURF        #
#######################################


def try_import(module_name, class_path=None, alias=None):
    """Import an optional submodule (or an attribute of it) if its directory exists.

    ``module_name`` must be a directory next to this file. While importing, a
    submodule-specific directory is temporarily put at the front of
    ``sys.path`` (``RAFT/core`` for names starting with "RAFT",
    ``SEARAFT/core`` for "SEARAFT", ``SMURF`` for "SMURF"), so that modules in
    that directory can be imported by bare name. Afterwards the path entry is
    removed again. For RAFT / SEA-RAFT, any newly imported module whose
    top-level name is one of "raft", "utils", "corr", "update" or "extractor"
    is evicted from ``sys.modules`` so that these generic names do not leak
    into later imports.

    Args:
        module_name: Submodule directory name, e.g. "RAFT".
        class_path: Optional dotted path ("/" is accepted as a separator and
            "-" is mapped to "_") whose last component is the attribute to
            return, e.g. "RAFT.core.raft.RAFT".
        alias: Name used in warning messages; defaults to ``module_name``.

    Returns:
        The requested attribute if ``class_path`` is given, otherwise the
        imported module. Returns ``None`` (after printing a warning) if the
        directory is missing or the import fails.
    """
    base_dir = os.path.dirname(__file__)
    display_name = alias or module_name

    temp_path = None
    cleanup_roots = set()
    if module_name.startswith("RAFT"):
        temp_path = os.path.join(base_dir, "RAFT", "core")
        cleanup_roots = {"raft", "utils", "corr", "update", "extractor"}
    elif module_name.startswith("SEARAFT"):
        temp_path = os.path.join(base_dir, "SEARAFT", "core")
        cleanup_roots = {"raft", "utils", "corr", "update", "extractor"}
    elif module_name.startswith("SMURF"):
        temp_path = os.path.join(base_dir, "SMURF")

    module_dir = os.path.join(base_dir, module_name)
    if not os.path.isdir(module_dir):
        print(f"[Warning] {display_name} directory not found. Skipping import.")
        return None

    inserted = False
    before_modules = set(sys.modules.keys())
    try:
        if temp_path and temp_path not in sys.path:
            sys.path.insert(0, temp_path)
            inserted = True

        if class_path:
            class_path = class_path.replace("/", ".").replace("-", "_")
            components = class_path.split(".")
            mod = importlib.import_module(".".join(components[:-1]))
            return getattr(mod, components[-1])

        return importlib.import_module(module_name.replace("-", "_"))
    except Exception as e:
        # Deliberately best-effort: an optional submodule that fails to import
        # must not break importing this file; callers check for ``None``.
        print(f"[Warning] {display_name} import failed: {e}")
        return None
    finally:
        if inserted:
            try:
                sys.path.remove(temp_path)
            except ValueError:
                pass
        # Clean up generic module names leaked during temporary import.
        leaked = set(sys.modules.keys()) - before_modules
        for mod_name in leaked:
            if mod_name.split(".", 1)[0] in cleanup_roots:
                sys.modules.pop(mod_name, None)


def _flow_to_grid_units(flow):
    """Rescale a pixel-displacement flow field in place and return it.

    ``flow`` is indexed as ``(B, C, H, W)``. Channel 0 is multiplied by
    ``2 / W`` and channel 1 by ``2 / H``. This converts pixel displacements
    to the [-1, 1] normalised coordinate scale used by ``F.affine_grid`` /
    ``F.grid_sample``.
    """
    flow[:, 0, :, :] *= 2 / flow.shape[3]
    flow[:, 1, :, :] *= 2 / flow.shape[2]
    return flow


# --------------------------------------------------------------------
# Optional flow-network submodules (each name is None when unavailable)
# --------------------------------------------------------------------

# --- RAFT ---
RAFT = try_import("RAFT", "RAFT.core.raft.RAFT", alias="RAFT")
InputPadder = try_import("RAFT", "RAFT.core.utils.utils.InputPadder", alias="RAFT InputPadder")

# --- SEA-RAFT ---
# The import is currently disabled. The name must still exist so that
# ``SeaFlow`` raises its explanatory ImportError rather than a NameError.
# SEARAFT = try_import("SEARAFT", "SEARAFT.core.raft.RAFT", alias="SEARAFT")
SEARAFT = None

# --- SMURF ---
SMURF_RAFT = try_import("SMURF", "SMURF.smurf.raft_smurf", alias="SMURF")


class RAFT_Args:
    """Static argument namespace passed to the RAFT constructor."""

    model = "models/raft-sintel.pth"
    step = 1
    small = False
    mixed_precision = False
    alternate_corr = False

    def __contains__(self, value):
        # Allows ``"name" in args`` to test for an attribute.
        return hasattr(self, value)


class RaftFlow:
    """Teacher optical flow from a pretrained RAFT model.

    The returned flow is rescaled from pixels to [-1, 1] grid units
    (see ``_flow_to_grid_units``).
    """

    def __init__(self, denorm_func):
        """
        Args:
            denorm_func: maps the normalised input images back to [0, 1]
                (its output is multiplied by 255 before being fed to RAFT).

        Raises:
            ImportError: if the RAFT submodule could not be imported.
        """
        if RAFT is None:
            raise ImportError(
                "RAFT not available. Ensure `RAFT/` exists and is importable."
            )
        args = RAFT_Args()
        # The checkpoint was saved from a DataParallel wrapper ("module." keys).
        model = nn.DataParallel(RAFT(args))
        model.load_state_dict(torch.load(args.model))
        self.model = model.module.eval()
        self.denorm = denorm_func

    @torch.no_grad()
    def __call__(self, image1, image2):
        """Return the flow between ``image1`` and ``image2`` in grid units."""
        image1_255 = self.denorm(image1) * 255.
        image2_255 = self.denorm(image2) * 255.

        padder = InputPadder(image1_255.shape)
        image1_pad, image2_pad = padder.pad(image1_255, image2_255)
        _, flow = self.model(image1_pad, image2_pad, iters=20, test_mode=True)
        flow = padder.unpad(flow)
        return _flow_to_grid_units(flow)


class SeaRAFTArgs:
    """Static argument namespace for SEA-RAFT (Hugging Face checkpoint)."""

    model_url = "MemorySlices/Tartan-C-T-TSKH-spring540x960-M"
    name = "spring-M"
    dataset = "spring"
    gpus = [0, 1, 2, 3, 4, 5, 6, 7]
    use_var = True
    var_min = 0
    var_max = 10
    pretrain = "resnet34"
    initial_dim = 64
    block_dims = [64, 128, 256]
    radius = 4
    dim = 128
    num_blocks = 2
    iters = 4

    def __contains__(self, value):
        # Allows ``"name" in args`` to test for an attribute.
        return hasattr(self, value)


class SeaFlow:
    """Teacher optical flow from SEA-RAFT, in [-1, 1] grid units."""

    def __init__(self, denorm_func):
        """
        Args:
            denorm_func: maps the normalised input images back to [0, 1].

        Raises:
            ImportError: if the SEA-RAFT submodule is not available.
        """
        if SEARAFT is None:
            raise ImportError(
                "SEARAFT not available. Ensure `SEARAFT/` exists and is importable."
            )
        print('Initializing Sea-RAFT')
        args = SeaRAFTArgs()
        # Use Hugging Face model repository
        model = SEARAFT.from_pretrained(args.model_url, args=args, force_download=True)
        self.model = model.eval()
        self.denorm = denorm_func

    @torch.no_grad()
    def __call__(self, image1, image2):
        """Return the last-iteration flow between the two images in grid units."""
        image1_255 = self.denorm(image1) * 255.
        image2_255 = self.denorm(image2) * 255.

        padder = InputPadder(image1_255.shape)
        image1_pad, image2_pad = padder.pad(image1_255, image2_255)
        output = self.model(image1_pad, image2_pad, iters=SeaRAFTArgs.iters, test_mode=True)
        flow = padder.unpad(output['flow'][-1])
        return _flow_to_grid_units(flow)


class SMURFArgs:
    """Checkpoint selection for SMURF (alternatives kept for reference)."""

    # checkpoint = "SMURF/models/smurf-sintel/smurf-sintel.pt"
    # checkpoint = "SMURF/models/smurf-chairs/smurf-chairs.pt"
    checkpoint = "SMURF/models/smurf-kitti/smurf-kitti.pt"

    def __contains__(self, value):
        # Allows ``"name" in args`` to test for an attribute.
        return hasattr(self, value)


class SMURF:
    """Teacher optical flow from SMURF, in [-1, 1] grid units."""

    def __init__(self, denorm_func):
        """
        Args:
            denorm_func: maps the normalised input images back to [0, 1].

        Raises:
            ImportError: if the SMURF submodule could not be imported.
        """
        if SMURF_RAFT is None:
            raise ImportError(
                "SMURF not available. Ensure `SMURF/` exists and is importable."
            )
        args = SMURFArgs()
        print(f'Initialize SMURF / {args.checkpoint}')
        model = SMURF_RAFT(checkpoint=args.checkpoint)
        self.model = model.eval()
        self.denorm = denorm_func

    @torch.no_grad()
    def __call__(self, image1, image2):
        """Return the last-iteration flow between the two images in grid units."""
        # SMURF is fed images rescaled from [0, 1] to [-1, 1]; no padding is applied.
        image1_norm = 2.0 * self.denorm(image1) - 1.0
        image2_norm = 2.0 * self.denorm(image2) - 1.0

        output = self.model(image1_norm, image2_norm)
        # last iteration
        return _flow_to_grid_units(output[-1])


######################
#       Models       #
######################

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from timm.models.vision_transformer import PatchEmbed, Block

from util.pos_embed import get_2d_sincos_pos_embed


class MAEEncoder(nn.Module):
    """Masked Autoencoder (encoder only) with a VisionTransformer backbone.

    ``forward`` returns the normalised token sequence (CLS token first) of
    shape ``(B, 1 + num_patches, embed_dim)``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        """``norm_pix_loss`` is accepted for signature compatibility; it is unused here."""
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed (non-trainable) position embedding; the values are expected to
        # come from the loaded checkpoint.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resample the patch position embedding to a new grid.

        Args:
            x: token tensor ``(B, 1 + npatch, dim)`` (CLS token first).
            w, h: input image size along the first / second spatial axis.

        Returns:
            Position embedding ``(1, 1 + (w // p) * (h // p), dim)``, or
            ``self.pos_embed`` unchanged when no resampling is needed.
        """
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]

        # timm's PatchEmbed stores patch_size as a 2-tuple; accept an int too.
        patch_size = self.patch_embed.patch_size
        if isinstance(patch_size, (tuple, list)):
            p0, p1 = patch_size[0], patch_size[1]
        else:
            p0 = p1 = patch_size
        w0 = w // p0
        h0 = h // p1

        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        side = int(math.sqrt(N))
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, imgs):
        """Encode images into a normalised token sequence (CLS token first)."""
        # embed patches
        x = self.patch_embed(imgs)

        # add pos embed w/o cls token
        # NOTE: the stored embedding is used as-is (its grid is sized for
        # ``img_size`` by the loaders); ``interpolate_pos_encoding`` is not used.
        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        return self.norm(x)


def _build_mae_encoder(cfg, embed_dim, depth, num_heads):
    """Build an ``MAEEncoder`` and load the MAE checkpoint ``cfg.enc_snapshot``.

    The checkpoint's absolute position embedding is resampled to the token grid
    implied by ``cfg.input_size`` and ``cfg.patch_size`` before loading.
    """
    from timm.layers import resample_abs_pos_embed

    model = MAEEncoder(img_size=cfg.input_size, patch_size=cfg.patch_size,
                       embed_dim=embed_dim, depth=depth, num_heads=num_heads,
                       mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6))

    print("MAE: output token grid size", model.patch_embed.grid_size)

    # The model lives on CPU here; load the weights there as well so loading
    # does not depend on the device the checkpoint was saved from.
    model_weights = torch.load(cfg.enc_snapshot, map_location="cpu")["model"]
    model_weights["pos_embed"] = resample_abs_pos_embed(model_weights["pos_embed"],
                                                        new_size=model.patch_embed.grid_size,
                                                        num_prefix_tokens=1,
                                                        interpolation='bicubic',
                                                        antialias=True,
                                                        verbose=True)
    model.load_state_dict(model_weights, strict=True)
    return model


def mae_vit_base_encoder(cfg):
    """MAE ViT-B/``cfg.patch_size`` encoder loaded from ``cfg.enc_snapshot``."""
    return _build_mae_encoder(cfg, embed_dim=768, depth=12, num_heads=12)


def mae_vit_large_encoder(cfg):
    """MAE ViT-L/``cfg.patch_size`` encoder loaded from ``cfg.enc_snapshot``."""
    return _build_mae_encoder(cfg, embed_dim=1024, depth=24, num_heads=16)


class CfgDPT():
    """Configuration of the DPT decoder for a given ViT backbone.

    Attributes:
        patch_size: ViT patch size.
        features: channel widths of the four decoder input levels.
        vit_features: ViT token width.
        hooks: indices of the four transformer blocks whose outputs are tapped.
    """

    def __init__(self, patch_size, features, fdim, hooks=(2, 5, 8, 11)):
        self.patch_size = patch_size
        self.features = features
        self.vit_features = fdim
        # Copy, so instances never share (and mutate) one list object.
        self.hooks = list(hooks)


def load_encoder(cfg):
    """Load the ViT backbone named by ``cfg.enc_snapshot`` and its DPT config.

    Supported values of ``cfg.enc_snapshot``:
      * ``dino_*`` (torch.hub DINO, ViT-S/B/L with patch 16),
      * ``dinov2_*`` (torch.hub DINOv2, ViT-S/B/L with patch 14),
      * paths ending in ``mae_pretrain_vit_base.pth`` /
        ``mae_pretrain_vit_large.pth`` (MAE checkpoints).

    Returns:
        ``(encoder, CfgDPT)``. For an unrecognised ``dino_*`` variant the
        second element is ``None``.

    Raises:
        NotImplementedError: if the snapshot is not supported.
    """
    snapshot = cfg.enc_snapshot

    if snapshot.startswith("dino_"):
        enc = torch.hub.load('facebookresearch/dino:main', snapshot)
        if snapshot.endswith("vits16"):
            return enc, CfgDPT(16, 4 * [384], 384)
        if snapshot.endswith("vitb16"):
            return enc, CfgDPT(16, 4 * [768], 768)
        if snapshot.endswith("vitl16"):
            return enc, CfgDPT(16, 4 * [1024], 1024)
        # Unknown DINOv1 variant: return the backbone without a DPT config.
        print("Did not find DINOv1 ", snapshot)
        return enc, None

    if snapshot.startswith("dinov2_"):
        enc = torch.hub.load('facebookresearch/dinov2', snapshot)
        if snapshot.endswith("vits14"):
            return enc, CfgDPT(14, 4 * [384], 384)
        if snapshot.endswith("vitb14"):
            return enc, CfgDPT(14, 4 * [768], 768)
        if snapshot.endswith("vitl14"):
            return enc, CfgDPT(14, 4 * [1024], 1024, [5, 11, 17, 23])
        print("Did not find DINOv2 ", snapshot)
        raise NotImplementedError(f"Unsupported encoder snapshot: {snapshot}")

    if snapshot.endswith("mae_pretrain_vit_base.pth"):
        print("Using MAE ViT-B")
        return mae_vit_base_encoder(cfg), CfgDPT(16, [96, 192, 384, 768], 768)

    if snapshot.endswith("mae_pretrain_vit_large.pth"):
        print("Using MAE ViT-L")
        return mae_vit_large_encoder(cfg), CfgDPT(16, [256, 512, 1024, 1024], 1024, [5, 11, 17, 23])

    raise NotImplementedError(f"Unsupported encoder snapshot: {snapshot}")


######################
#        DPT         #
######################


class SkipCLS(nn.Module):
    """Drop the first ``start_index`` tokens (e.g. the CLS token) along dim 1."""

    def __init__(self, start_index=1):
        super(SkipCLS, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose(dim0, dim1)``."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)


class Postprocess(nn.Module):
    """Turn ViT tokens ``(B, 1 + N, C)`` into a feature map and apply ``conv_layers``.

    The CLS token is dropped and the remaining tokens are reshaped to the
    ``(H // patch_size, W // patch_size)`` grid of the input size ``hw``.
    """

    def __init__(self, patch_size, *conv_layers):
        super().__init__()
        self.patch_size = patch_size
        self.pre = nn.Sequential(SkipCLS(), Transpose(1, 2))
        self.post = nn.Sequential(*conv_layers)

    def forward(self, x, hw):
        x = self.pre(x)
        grid = (hw[0] // self.patch_size, hw[1] // self.patch_size)
        return self.post(x.unflatten(2, grid))


def save_activation(model, name):
    """Return a forward hook that stores the module output as ``model.<name>``."""
    assert not hasattr(model, name), f"Model already has attribute {name}"

    def hook(module, input, output):
        setattr(model, name, output)

    return hook


def dpt_wrapper(model, hooks=(2, 5, 8, 11)):
    """Register hooks storing the outputs of ``model.blocks[hooks[i]]``.

    The outputs of the first four hooked blocks become ``model.layer_1``
    ... ``model.layer_4`` after every forward pass.
    """
    for i in range(4):
        model.blocks[hooks[i]].register_forward_hook(save_activation(model, f"layer_{i + 1}"))


from util.dpt_blocks import (
    _make_scratch,
    FeatureFusionBlock,
    Interpolate,
    LayerNormBCHW,
)


class DecodeDPT(nn.Module):
    """DPT-style decoder fusing four hooked ViT activations into a dense feature map.

    The output has ``features_final`` channels and the spatial size ``hw``
    passed to ``forward``; it is optionally normalised with ``LayerNormBCHW``.
    """

    def __init__(self, cfg, features_out=256, features_final=128, non_negative=True):
        """``non_negative`` is accepted for compatibility; it is unused here."""
        super(DecodeDPT, self).__init__()

        self.scratch = _make_scratch(cfg.features, features_out)

        # Each hooked activation is projected to its level width with a 1x1
        # conv and then resampled: x4 up, x2 up, unchanged, x2 down.
        self.act_postprocess1 = Postprocess(
            cfg.patch_size,
            nn.Conv2d(
                in_channels=cfg.vit_features,
                out_channels=cfg.features[0],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=cfg.features[0],
                out_channels=cfg.features[0],
                kernel_size=4,
                stride=4,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        self.act_postprocess2 = Postprocess(
            cfg.patch_size,
            nn.Conv2d(
                in_channels=cfg.vit_features,
                out_channels=cfg.features[1],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=cfg.features[1],
                out_channels=cfg.features[1],
                kernel_size=2,
                stride=2,
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            ),
        )

        self.act_postprocess3 = Postprocess(
            cfg.patch_size,
            nn.Conv2d(
                in_channels=cfg.vit_features,
                out_channels=cfg.features[2],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        self.act_postprocess4 = Postprocess(
            cfg.patch_size,
            nn.Conv2d(
                in_channels=cfg.vit_features,
                out_channels=cfg.features[3],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Conv2d(
                in_channels=cfg.features[3],
                out_channels=cfg.features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            ),
        )

        self.scratch.refinenet4 = FeatureFusionBlock(features_out)
        self.scratch.refinenet3 = FeatureFusionBlock(features_out)
        self.scratch.refinenet2 = FeatureFusionBlock(features_out)
        self.scratch.refinenet1 = FeatureFusionBlock(features_out)

        self.scratch.output_conv0 = nn.Conv2d(features_out, features_out, kernel_size=3, stride=1, padding=1)

        self.scratch.output_conv1 = nn.Sequential(
            nn.Conv2d(features_out, features_out, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(features_out, features_final, kernel_size=1, stride=1, padding=0),
        )

        self.norm = LayerNormBCHW(features_final)

    def forward(self, enc, hw, with_norm=True):
        """Decode the activations stored on ``enc`` by ``dpt_wrapper``.

        Args:
            enc: encoder after a forward pass (reads ``enc.layer_1`` .. ``enc.layer_4``).
            hw: ``(H, W)`` input size; also the output spatial size.
            with_norm: apply ``self.norm`` to the output.
        """
        layer_1 = self.act_postprocess1(enc.layer_1, hw)
        layer_2 = self.act_postprocess2(enc.layer_2, hw)
        layer_3 = self.act_postprocess3(enc.layer_3, hw)
        layer_4 = self.act_postprocess4(enc.layer_4, hw)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine fusion.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv0(path_1)
        out = F.interpolate(out, hw, mode="bilinear", align_corners=False)
        out = self.scratch.output_conv1(out)

        if with_norm:
            out = self.norm(out)

        return out


######################
#      FlowFeat      #
######################


class FlowFeat(nn.Module):
    """ViT encoder (with DPT hooks) plus a DPT decoder producing dense features."""

    def __init__(self, cfg):
        """
        Raises:
            ValueError: if ``load_encoder`` returns no DPT configuration for
                ``cfg.enc_snapshot``.
        """
        super().__init__()
        self.cfg = cfg
        self.encoder, dpt_cfg = load_encoder(cfg)
        if dpt_cfg is None:
            raise ValueError(f"No DPT configuration for encoder snapshot: {cfg.enc_snapshot}")

        # encoder
        dpt_wrapper(self.encoder, dpt_cfg.hooks)

        # decoder
        self.decoder = DecodeDPT(dpt_cfg, cfg.fdim, cfg.fdim)

    @torch.no_grad()
    def forward_up(self, x):
        """Return ``(x_enc, x_dec)`` for images ``x`` of shape ``(B, C, H, W)``.

        ``x_enc`` is the normalised last hooked token map reshaped to
        ``(B, D, H // patch_size, W // patch_size)``; ``x_dec`` is the decoder
        output at resolution ``(H, W)``.
        """
        b, _, h, w = x.shape
        hh = h // self.cfg.patch_size
        ww = w // self.cfg.patch_size

        # Run the encoder only to populate the hooked layer_* attributes.
        self.encoder(x)

        x_enc = self.encoder.layer_4[:, 1:]  # skipping cls token
        x_enc = self.encoder.norm(x_enc)
        x_enc = x_enc.movedim(1, -1).view(b, -1, hh, ww)

        x = self.decoder(self.encoder, (h, w))
        return x_enc, x

    def forward(self, x):
        return self.forward_up(x)


class FlowFeatTrain(FlowFeat):
    """FlowFeat with the flow-supervised training objective.

    A teacher flow network (``cfg.flownet`` names a class in this module,
    e.g. ``RaftFlow``) gives flow between two frames. Two affine crops of
    ``frame0`` are decoded: crop 1 by the EMA decoder, crop 2 by the student
    decoder. A ridge regression from the EMA features of crop 1 to the teacher
    flow of crop 1 is solved in closed form and applied to the student
    features of crop 2; the result is compared with the teacher flow of crop 2.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        self.fdim = cfg.fdim
        self.ridge_alpha = cfg.ridge_alpha
        self.input_size = cfg.input_size

        assert cfg.input_size[0] % cfg.patch_size == 0, "Height is not divisible by patch size"
        assert cfg.input_size[1] % cfg.patch_size == 0, "Width is not divisible by patch size"

        # Loss functions are selected by name, e.g. cfg.flow_loss = "l1" -> flow_l1.
        self.flow_loss = getattr(self, "flow_" + cfg.flow_loss)
        self.edge_loss = getattr(self, "edge_" + cfg.edge_loss)

        # EMA decoder
        self.decoder_ema = EMA(self.decoder,
                               beta=cfg.decoder_momentum,
                               update_after_step=1,
                               update_every=cfg.decoder_update_every)

        # Inverse / forward of the ImageNet mean-std normalisation.
        self.denorm = tf.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
                                   std=[1 / 0.229, 1 / 0.224, 1 / 0.225])

        self.norm = tf.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

        self.flow = globals()[cfg.flownet](self.denorm)

    def parameter_groups(self):
        """Optimiser parameter groups: only the (student) decoder is trained."""
        return [{"name": "decoder", "params": self.decoder.parameters()}]

    def flow_l1(self, pred_flow, teach_flow, **kwargs):
        """Mean over pixels of the per-pixel L1 distance summed over channels."""
        l1_dist = torch.abs(pred_flow - teach_flow).sum(1, keepdim=True)
        return l1_dist.mean()

    def flow_l2(self, pred_flow, teach_flow, **kwargs):
        return F.mse_loss(pred_flow, teach_flow)

    def flow_l1smooth(self, pred_flow, teach_flow, beta=1., **kwargs):
        return F.smooth_l1_loss(pred_flow, teach_flow, beta=beta).mean()

    def flow_l1huber(self, pred_flow, teach_flow, beta=1., **kwargs):
        return F.huber_loss(pred_flow, teach_flow, delta=beta).mean()

    def edge_l1(self, x, y, sigma=1.0):
        """Elementwise L1 weighted by ``1 - exp(-y / sigma)`` (larger where ``y`` is large)."""
        return (1. - torch.exp(-y / sigma)) * torch.abs(x - y)

    def edge_l1norm(self, x, y, sigma=1.0, eps=1e-8):
        """L1 weighted by ``1 - exp(-y / sigma)``, weights normalised over the last two dims.

        ``eps`` keeps the normalisation finite when ``y`` is all zero (e.g. a
        constant flow channel); the loss is then 0 instead of NaN.
        """
        w = 1. - torch.exp(-y / sigma)
        w = w / (w.sum((-1, -2), keepdim=True) + eps)
        return (w * torch.abs(x - y)).sum((-1, -2))

    def edge_l2norm(self, x, y, sigma=1.0, eps=1e-8):
        """Squared error weighted like ``edge_l1norm`` (same ``eps`` safeguard)."""
        w = 1. - torch.exp(-y / sigma)
        w = w / (w.sum((-1, -2), keepdim=True) + eps)
        return (w * (x - y) ** 2).sum((-1, -2))

    def edge_l1smooth(self, x, y, sigma=1.0):
        return F.smooth_l1_loss(x, y, beta=sigma)

    def edge_l1huber(self, x, y, sigma=1.0):
        return F.huber_loss(x, y, delta=sigma)

    def flow_boundary_loss(self, gt_flow, pred_flow, sigma, eps=1e-5):
        """Compare absolute horizontal/vertical finite differences of the flows.

        ``self.edge_loss`` is applied to the x- and y-gradients separately and
        the two means are summed. ``eps`` is currently unused.
        """
        def grad_x(t):
            return torch.abs(t[:, :, :, :-1] - t[:, :, :, 1:])

        def grad_y(t):
            return torch.abs(t[:, :, :-1, :] - t[:, :, 1:, :])

        loss_dx = self.edge_loss(grad_x(pred_flow), grad_x(gt_flow), sigma)
        loss_dy = self.edge_loss(grad_y(pred_flow), grad_y(gt_flow), sigma)
        return loss_dx.mean() + loss_dy.mean()

    def update_ema(self):
        self.decoder_ema.update()

    def forward_flow(self, flow, rX_ema, rX, alpha=0.1, mask_ratio=0.75):
        """Fit a per-sample ridge regression from features to flow and apply it.

        With ``X = [rX_ema, 1]`` (pixels x channels, bias column appended) and
        ``Y = flow`` (pixels x channels), solves per batch element
        ``(X^T X / N + alpha * I) A = X^T Y / N`` and predicts ``[rX, 1] @ A``.

        Args:
            flow: target, ``(B, d, h, w)`` with ``h * w`` equal to the pixel
                count of ``rX_ema``.
            rX_ema: features the regression is fitted on, ``(B, C, h, w)``.
            rX: features the regression is applied to, ``(B, C, H, W)``.
            alpha: ridge strength; no regularisation if ``alpha <= 0``.
            mask_ratio: currently unused.

        Returns:
            ``(pred_flow, A)`` with shapes ``(B, d, H, W)`` and ``(B, C + 1, d)``.
        """
        def to_mat(t):
            # (B, C, H, W) -> (B, H*W, C)
            return t.flatten(2, 3).movedim(1, -1)

        def add_one(t):
            # Append a constant bias column.
            return torch.cat([t, torch.ones_like(t[..., :1])], -1)

        b, _, H, W = rX.shape

        flow_mat = to_mat(flow)
        d = flow_mat.shape[-1]

        X_ema = add_one(to_mat(rX_ema))
        n, f = X_ema.shape[1], X_ema.shape[2]

        lhs = X_ema.transpose(-1, -2) @ X_ema / n
        if alpha > 0.:
            lhs = lhs + alpha * torch.eye(f, dtype=lhs.dtype, device=lhs.device)

        rhs = X_ema.transpose(-1, -2) @ flow_mat / n
        A = torch.linalg.lstsq(lhs, rhs).solution

        pred_flow = add_one(to_mat(rX)) @ A
        pred_flow = pred_flow.movedim(1, -1).view(b, d, H, W)
        return pred_flow, A

    def forward_enc(self, crop1, crop2):
        """Return ``(student features of crop2, EMA-teacher features of crop1)``."""
        _, _, h, w = crop1.shape

        # teacher: frozen encoder pass populates the hooks, EMA decoder decodes
        with torch.no_grad():
            self.encoder(crop1)
        xT = self.decoder_ema(self.encoder, (h, w))

        # student
        with torch.no_grad():
            self.encoder(crop2)
        xS = self.decoder(self.encoder, (h, w))

        return xS, xT

    def crop_view(self, frame, params, input_size):
        """Resample ``frame`` with affine matrices ``params`` to ``input_size``.

        ``params`` is passed to ``F.affine_grid`` (which expects ``(B, 2, 3)``).
        Returns the crop and the sampling grid.
        """
        b = frame.shape[0]
        affine_grid = F.affine_grid(params, (b, 1, input_size[0], input_size[1]), align_corners=False)
        frame_crop = F.grid_sample(frame, affine_grid, align_corners=False)
        return frame_crop, affine_grid

    def forward(self, frames, frame0, params1, params2, epoch=0.):
        """Compute the training losses and visualisation outputs.

        Args:
            frames: [B, T, 3, H, W]; the teacher flow is
                ``self.flow(frames[:, 0], frames[:, 1])``.
            frame0: image the two crops are taken from.
            params1, params2: affine matrices for the two crops (see ``crop_view``).
            epoch: currently unused.

        Returns:
            ``(losses, outs)``: ``losses`` has keys "flowres", "flowbdr" and
            "total"; ``outs`` has "flow_0", "t_flow_0", "features", "crop1",
            "crop2".
        """
        ### compute teacher flow
        teacher_flow = self.flow(frames[:, 0], frames[:, 1])

        # normalising the flow (per sample and channel)
        if self.cfg.norm_flow:
            flow_mean = teacher_flow.mean((2, 3), keepdim=True)
            flow_std = teacher_flow.std((2, 3), keepdim=True)
            teacher_flow = (teacher_flow - flow_mean) / (flow_std + 1e-5)

        ### main ###
        crop1, affine_grid1 = self.crop_view(frame0, params1, self.cfg.input_size)
        crop2, affine_grid2 = self.crop_view(frame0, params2, self.cfg.input_size)

        features, features_ema = self.forward_enc(crop1, crop2)

        outs = {}
        losses = {}

        # Teacher flow resampled with the same grids as the two crops.
        teacher_flow1 = F.grid_sample(teacher_flow, affine_grid1, align_corners=False)
        teacher_flow2 = F.grid_sample(teacher_flow, affine_grid2, align_corners=False)

        # student flow
        student_flow, A = self.forward_flow(teacher_flow1, features_ema, features, self.ridge_alpha)

        losses["flowres"] = self.flow_loss(student_flow, teacher_flow2, beta=self.cfg.flow_beta)
        losses["flowbdr"] = self.flow_boundary_loss(teacher_flow2, student_flow, self.cfg.flow_edge_sigma)

        losses["total"] = 0.
        losses["total"] += self.cfg.flow_weight * losses["flowres"]
        losses["total"] += self.cfg.flow_edge_weight * losses["flowbdr"]

        tag = "flow_0"
        outs[tag] = student_flow.movedim(1, -1)[..., :2]
        outs["t_" + tag] = teacher_flow2.movedim(1, -1)
        outs["features"] = features
        outs["crop1"] = crop1
        outs["crop2"] = crop2

        return losses, outs