pixel_gen / src /models /conditioner /class_label.py
linxin02's picture
Upload lx_gan project
cef8b68 verified
Raw
History Blame Contribute Delete
436 Bytes
import torch
from src.models.conditioner.base import BaseConditioner
class LabelConditioner(BaseConditioner):
    """Convert class labels into ``torch.long`` index tensors.

    ``_impl_condition`` returns the given labels as a tensor, while
    ``_impl_uncondition`` returns a tensor of the same length filled with
    the reserved ``null_condition`` index.

    NOTE(review): both hooks are presumably invoked by ``BaseConditioner``;
    confirm the calling convention against the base class.
    """

    def __init__(self, num_classes):
        """Remember the index that stands for "no label".

        Args:
            num_classes: Stored verbatim as ``null_condition``.
        """
        super().__init__()
        # NOTE(review): presumably real labels are 0..num_classes-1, making
        # this one past the last real class -- confirm that the consumer
        # (e.g. an embedding table) has room for the extra index.
        self.null_condition = num_classes

    @staticmethod
    def _label_device():
        """Return the CUDA device when available, otherwise the CPU.

        The bare ``"cuda"`` device resolves to the current CUDA device, the
        same target ``Tensor.cuda()`` uses; the CPU fallback keeps the
        conditioner usable on hosts without a GPU.
        """
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _impl_condition(self, y, metadata):
        """Return the labels ``y`` as a ``torch.long`` tensor.

        Args:
            y: Labels, either a tensor or anything ``torch.tensor`` accepts.
            metadata: Unused by this conditioner.

        Returns:
            A new ``torch.long`` tensor on the device chosen by
            ``_label_device``.
        """
        if isinstance(y, torch.Tensor):
            # ``torch.tensor(tensor)`` warns about copy-construction; an
            # explicit detach + clone makes the same independent copy.
            labels = y.detach().clone()
        else:
            labels = torch.tensor(y)
        return labels.to(device=self._label_device(), dtype=torch.long)

    def _impl_uncondition(self, y, metadata):
        """Return ``len(y)`` copies of the null-condition index.

        Args:
            y: Labels; only their count (``len(y)``) is used.
            metadata: Unused by this conditioner.

        Returns:
            A 1-D ``torch.long`` tensor of length ``len(y)`` filled with
            ``self.null_condition``.
        """
        # Allocate directly on the target device rather than building the
        # tensor on the CPU and copying it over afterwards.
        return torch.full(
            (len(y),),
            self.null_condition,
            dtype=torch.long,
            device=self._label_device(),
        )