| import os | |
| import sys | |
| import torch | |
| import torch.nn as nn | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from trellis2.quantization import quantize_model | |
| def main(): | |
| m = nn.Sequential( | |
| nn.Linear(16, 8), | |
| nn.ReLU(), | |
| nn.Linear(8, 4) | |
| ) | |
| print('model loaded') | |
| quantize_model(m, bits=4, dtype=torch.float16) | |
| print('model quantized') | |
| x = torch.randn(2, 16) | |
| y = m(x) | |
| print('forward ok') | |
| print('output shape:', y.shape) | |
| print('output dtype:', y.dtype) | |
| print('output sample:', y[0].tolist()) | |
| if __name__ == '__main__': | |
| main() | |
| import torch | |
| print("allocated GB:", torch.cuda.memory_allocated() / 1024**3) | |
| print("reserved GB:", torch.cuda.memory_reserved() / 1024**3) | |