Harmony18090's picture
Add source batch 2/11
76f9669 verified
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from compressed_tensors.transform import HadamardFactory, TransformFactory
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
from torch import device, dtype
from torch.nn import Parameter
@TransformFactory.register("random-hadamard")
class RandomHadamardFactory(HadamardFactory):
    """
    Hadamard transform factory whose weights come from
    ``random_hadamard_matrix`` and are drawn using this factory's generator.

    Registered with :class:`TransformFactory` under the name
    ``"random-hadamard"``.

    :param name: name associated with transform scheme
    :param scheme: transform scheme which defines how transforms should be created
    :param seed: random seed used to randomize the transform weights
    """

    def _create_weight(
        self,
        size: int,
        device: device,
        construct_device: device,
        precision: dtype,
    ) -> Parameter:
        """
        Build the random hadamard weight for a transform.

        :param size: size forwarded to ``random_hadamard_matrix``
            (presumably the side length of the square matrix; confirm
            against that helper)
        :param device: device the finished weight is moved to
        :param construct_device: device handed to ``random_hadamard_matrix``
            for constructing the matrix
        :param precision: dtype handed to ``random_hadamard_matrix``
        :return: parameter wrapping the matrix, with ``requires_grad`` taken
            from ``self.scheme.requires_grad``
        """
        # Construct on `construct_device` first; the move to the destination
        # device happens only once the matrix has been fully built.
        hadamard = random_hadamard_matrix(
            size, precision, construct_device, self.generator
        )
        return Parameter(
            hadamard.to(device=device),
            requires_grad=self.scheme.requires_grad,
        )