# File size: 449 Bytes
# 61029c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch

class _FunctionCorrelation(torch.autograd.Function):
    """Placeholder ``torch.autograd.Function`` named for a correlation op.

    No computation is provided: ``forward`` raises ``NotImplementedError``
    every time it is called, and no ``backward`` is defined on this class.
    """

    @staticmethod
    def forward(self, first, second):
        """Raise ``NotImplementedError`` unconditionally.

        All three arguments (the leading one is named ``self`` even though
        this is a static method) are accepted but never read.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

def FunctionCorrelation(tenFirst, tenSecond):
    """Stub entry point for the correlation operation; always raises.

    Args:
        tenFirst: First input. Accepted but not used.
        tenSecond: Second input. Accepted but not used.

    Raises:
        NotImplementedError: Always; no correlation is computed here.
    """
    # NOTE(review): an unreachable `return` that followed this raise was
    # removed. Presumably, once a real forward pass exists, this function
    # should delegate to `_FunctionCorrelation.apply(tenFirst, tenSecond)`.
    raise NotImplementedError(
        "FunctionCorrelation is a stub: the correlation operation is not "
        "implemented in this module"
    )

class ModuleCorrelation(torch.nn.Module):
    """Stub ``torch.nn.Module`` named for a correlation layer.

    The class cannot be instantiated: its constructor always raises
    ``NotImplementedError``. No ``forward`` method is defined on it.
    """

    def __init__(self):
        """Initialise the base module, then fail because this is a stub.

        Raises:
            NotImplementedError: Always.
        """
        # Run the base-class initialiser before raising. It previously sat
        # after the raise and could never execute.
        super(ModuleCorrelation, self).__init__()
        raise NotImplementedError(
            "ModuleCorrelation is a stub: the correlation operation is not "
            "implemented in this module"
        )