andro / kaggle_job_quantize /quantize.py
krsnalyst's picture
Upload 19 files
5885a23 verified
import os
import subprocess
import sys
import threading
import time
import shutil
# Hide every CUDA device from this process before any ML library is imported,
# so CUDA-aware libraries (here and in subprocesses spawned later, which inherit
# os.environ) see no GPU.
os.environ.update(CUDA_VISIBLE_DEVICES="")
def log(msg):
    """Print *msg* prefixed with the current local time as ``[HH:MM:SS]``.

    stdout is flushed on every call so progress appears live in the job log.
    """
    stamp = time.strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)
def run_cmd(cmd):
    """Log the shell command string *cmd*, then execute it via the shell.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status; nothing catches it here, so the script aborts.
    """
    log(f"Running: {cmd}")
    # Commands are built internally by this script (not from user input),
    # so passing a single string to the shell is acceptable here.
    subprocess.run(cmd, shell=True, check=True)
# Heartbeat thread: keeps emitting log lines during long, silent steps
# (pip installs, quantization) so it is visible that the job is still running.
def heartbeat():
    """Log the seconds elapsed since start, once every 60 seconds, forever."""
    started_at = time.time()
    while True:
        time.sleep(60)
        log(f"HEARTBEAT: still alive after {int(time.time() - started_at)}s")


# Daemon thread: it never returns, and must not keep the interpreter alive
# once the main script finishes.
t = threading.Thread(target=heartbeat, daemon=True)
t.start()
# ============================================================
# Step 1: Find the model.tflite from the previous kernel output
# ============================================================
log("Looking for model.tflite from previous kernel output...")
# Kaggle mounts kernel sources under /kaggle/input/<kernel-slug>/
input_base = "/kaggle/input"
tflite_candidates = []
for root, dirs, files in os.walk(input_base):
    # Sort in place (os.walk is top-down by default) so the traversal order,
    # the log, and the fallback choice below do not depend on filesystem order.
    dirs.sort()
    for f in sorted(files):
        fpath = os.path.join(root, f)
        size_mb = os.path.getsize(fpath) / (1024 * 1024)
        log(f" Found: {fpath} ({size_mb:.1f} MB)")
        if f.endswith(".tflite"):
            tflite_candidates.append(fpath)

# When several .tflite files are present, "whichever os.walk yields last" is
# arbitrary. Prefer a file literally named model.tflite; otherwise take the
# largest .tflite, assuming the main model graph is the biggest file.
named_model = [p for p in tflite_candidates if os.path.basename(p) == "model.tflite"]
if named_model:
    tflite_path = named_model[0]
elif tflite_candidates:
    tflite_path = max(tflite_candidates, key=os.path.getsize)
else:
    tflite_path = None
if len(tflite_candidates) > 1:
    log(f"Multiple .tflite files found ({len(tflite_candidates)}); using {tflite_path}")

if tflite_path is None:
    log("ERROR: No .tflite file found in input!")
    log("Listing all input directories:")
    for root, dirs, files in os.walk(input_base):
        log(f" DIR: {root} ({len(files)} files)")
    sys.exit(1)
size_gb = os.path.getsize(tflite_path) / (1024 * 1024 * 1024)
log(f"Found model: {tflite_path} ({size_gb:.2f} GB)")
# ============================================================
# Step 2: Install quantization dependencies
# ============================================================
log("Installing dependencies...")
# Install through the interpreter running this script so packages land in the
# same environment: CPU-only PyTorch wheels first, then the LiteRT / Hugging
# Face tooling. Any failing install aborts the script (run_cmd raises).
# NOTE(review): ai_edge_quantizer (imported in the quantization step) is not
# listed explicitly; presumably it comes in transitively via litert-torch. Verify.
_PIP_INSTALL = f"{sys.executable} -m pip install -U"
for _pkg_args in (
    "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu",
    "litert-torch torchao transformers huggingface-hub mediapipe accelerate sentencepiece 'protobuf>=6.0'",
):
    run_cmd(f"{_PIP_INSTALL} {_pkg_args}")
log("All dependencies installed.")
# ============================================================
# Step 3: Quantize the model
# ============================================================
log("=== QUANTIZING MODEL (dynamic_wi8_afp32) ===")
# Work on a copy in the writable working directory (the input mount is
# read-only). The copy is deleted once the quantized model has been exported.
work_tflite = os.path.join("/kaggle/working", "model.tflite")
log(f"Copying {tflite_path} -> {work_tflite}...")
shutil.copy2(tflite_path, work_tflite)
log("Copy done.")

# Also bring along every other regular file that sits next to the selected
# model (embedder, tokenizer, etc.), in os.listdir order.
input_dir = os.path.dirname(tflite_path)
for sibling in os.listdir(input_dir):
    sibling_src = os.path.join(input_dir, sibling)
    if sibling_src == tflite_path or not os.path.isfile(sibling_src):
        continue
    log(f"Copying {sibling}...")
    shutil.copy2(sibling_src, os.path.join("/kaggle/working", sibling))

# Imported only now, after the dependency-install step has run (presumably
# because the package may not be present before it).
from ai_edge_quantizer import quantizer as quant_lib, recipe as recipe_lib

log("Starting quantization...")
quantized_path = "/kaggle/working/model_quantized.tflite"
qt = quant_lib.Quantizer(work_tflite)
# Recipe named dynamic_wi8_afp32 (by its name: int8 weights, fp32 activations).
qt.load_quantization_recipe(recipe_lib.dynamic_wi8_afp32())
log("Running quantization (this will take a while)...")
quant_result = qt.quantize()
quant_result.export_model(quantized_path, overwrite=True)
size_gb = os.path.getsize(quantized_path) / 1024**3
log(f"Quantized model saved: {quantized_path} ({size_gb:.2f} GB)")

# Reclaim disk space: the float model copy is no longer needed.
os.remove(work_tflite)
log("Removed unquantized copy.")
# ============================================================
# Step 4: Bundle into .litertlm
# ============================================================
log("=== BUNDLING INTO .litertlm ===")
# Bundling is best-effort: if anything here fails, including importing the
# litert_torch export helpers, the quantized .tflite is still the main output.
output_bundle = "/kaggle/working/gemma-4-E2B-it-uncensored.litertlm"
try:
    # Imports live inside the try so that a missing or renamed module counts
    # as a non-critical bundling failure instead of crashing the whole job.
    import litert_torch.generative.export_hf.export as export_lib  # noqa: F401
    from litert_torch.generative.export_hf.core import bundle_utils

    bundle_utils.bundle_litert_lm(
        model_path=quantized_path,
        output_path=output_bundle,
        model_type="gemma4",
    )
    size_gb = os.path.getsize(output_bundle) / (1024 * 1024 * 1024)
    log(f"Bundle saved: {output_bundle} ({size_gb:.2f} GB)")
except Exception as e:  # deliberate best-effort boundary; failure is logged
    log(f"Bundling failed (not critical): {e}")
    # Don't leave a half-written bundle behind to be mistaken for real output.
    if os.path.exists(output_bundle):
        try:
            os.remove(output_bundle)
            log(f"Removed incomplete bundle: {output_bundle}")
        except OSError as cleanup_err:
            log(f"Could not remove incomplete bundle: {cleanup_err}")
    log("The quantized .tflite is still available as output.")
# ============================================================
# Final: List all output files
# ============================================================
log("=== OUTPUT FILES ===")
# Report every regular file left in the working directory with its size in MB.
# Reached whether or not the optional bundling step succeeded.
working_dir = "/kaggle/working"
for entry in os.listdir(working_dir):
    entry_path = os.path.join(working_dir, entry)
    if not os.path.isfile(entry_path):
        continue
    entry_mb = os.path.getsize(entry_path) / (1024 * 1024)
    log(f" {entry} ({entry_mb:.1f} MB)")
log("SUCCESS! Quantization complete.")