bbkdevops's picture
download
raw
1.76 kB
"""
INT8 Dynamic Quantization — inference-time only, zero VRAM overhead ตอน train
เป้าหมาย: ใช้ VRAM น้อยลง ~50% ตอน serve โดยสูญเสีย accuracy < 1%
"""
import copy
from pathlib import Path

import torch
import torch.nn as nn
def quantize_model(model: nn.Module) -> nn.Module:
    """Return a dynamically INT8-quantized copy of *model*.

    Every ``nn.Linear`` is swapped for its dynamic-quantized counterpart:
    weights are stored as ``qint8`` ahead of time and activations are
    quantized on the fly at inference, so no calibration dataset is needed.

    PyTorch's dynamic-quantization kernels are CPU-only, so the copy is moved
    to CPU before conversion (packing the weights of a CUDA-resident model
    would fail). The caller's model keeps its device and weights.

    Args:
        model: Model to quantize. As a side effect it is switched to eval
            mode (``model.eval()`` is called on it directly).

    Returns:
        A new module on CPU whose Linear layers are dynamically quantized.
    """
    model.eval()
    # Copy first so moving to CPU does not relocate the caller's model, then
    # quantize the copy in place to avoid a second deep copy.
    cpu_copy = copy.deepcopy(model).to("cpu")
    return torch.quantization.quantize_dynamic(
        cpu_copy,
        qconfig_spec={nn.Linear},
        dtype=torch.qint8,
        inplace=True,
    )
def save_quantized(model: nn.Module, path: str | Path):
"""บันทึก quantized model"""
path = Path(path)
path.parent.mkdir(exist_ok=True)
torch.save(model.state_dict(), path)
size_mb = path.stat().st_size / 1024 / 1024
print(f"Quantized model saved → {path} ({size_mb:.1f} MB)")
def _packed_weight_bytes(model: nn.Module) -> int:
    """Return bytes held by quantized packed weights, which ``parameters()`` does not report."""
    total = 0
    for module in model.modules():
        packed = getattr(module, "_packed_params", None)
        # A dynamic-quantized Linear exposes `_packed_params` as a child module;
        # only that inner holder stores the opaque packed object. Counting at
        # the holder level avoids double-counting the same weights.
        if packed is None or isinstance(packed, nn.Module):
            continue
        unpack = getattr(module, "_weight_bias", None)
        if not callable(unpack):
            continue
        for tensor in unpack():
            if tensor is not None:  # bias may be absent
                total += tensor.numel() * tensor.element_size()
    return total


def estimate_vram(model: nn.Module) -> str:
    """Estimate the memory held by *model*'s weights, as a human-readable string.

    Sums the bytes of all parameters, all buffers, and any quantized packed
    weights. Dynamic-quantized Linear layers keep their weights outside
    ``parameters()``, so these must be counted separately. Activations,
    optimizer state and allocator overhead are not included. The figure is
    computed the same way regardless of which device the model lives on.

    Returns:
        A string such as ``"4 MB (0.00 GB)"``.
    """
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    total_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    total_bytes += _packed_weight_bytes(model)
    mb = total_bytes / 1024 / 1024
    return f"{mb:.0f} MB ({mb/1024:.2f} GB)"
if __name__ == "__main__":
    # Demo: build the small-config OmegaModel and report its weight
    # footprint before and after INT8 dynamic quantization.
    import sys

    # Put the directory two levels above this file on sys.path so the
    # project's `model` package can be imported when run as a script.
    repo_root = Path(__file__).parent.parent
    sys.path.insert(0, str(repo_root))

    from model.config import small_config
    from model.architecture import OmegaModel

    fp_model = OmegaModel(small_config())
    print(f"Before: {estimate_vram(fp_model)}")

    int8_model = quantize_model(fp_model)
    print(f"After INT8: {estimate_vram(int8_model)}")

Xet Storage Details

Size:
1.76 kB
·
Xet hash:
ca6d1023900afb155b24de7060e5922f16f7152680f718a477375e6aba85fb25

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.