import json
import logging
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional

import numpy as np
import torch

from optimizations.quantize import ModelQuantizer

logger = logging.getLogger(__name__)

@dataclass
class ModelMetrics:
    """Container for size, timing, and output-comparison metrics."""

    model_sizes: Dict[str, float]
    inference_times: Dict[str, float]
    comparison_metrics: Dict[str, Any]
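
# Illustrative ModelMetrics instance (made-up values, shown only to document
# the expected shape of each field):
#   ModelMetrics(
#       model_sizes={"original": 438.12, "quantized": 113.05},
#       inference_times={"original": 0.2143, "quantized": 0.0987},
#       comparison_metrics={"max_abs_diff": 0.0031},
#   )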


class ModelHandler:
    """Base class for handling different types of models"""

    def __init__(self, model_name, model_class, quantization_type, test_text=None):
        self.model_name = model_name
        self.model_class = model_class
        self.quantization_type = quantization_type
        self.test_text = test_text
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load both variants up front so they can be compared later.
        self.original_model = self._load_original_model()
        self.quantized_model = self._load_quantized_model()
        self.metrics: Optional[ModelMetrics] = None

    def _load_original_model(self):
        """Load the original (full-precision) model"""
        model = self.model_class.from_pretrained(self.model_name)
        return model.to(self.device)

    def _load_quantized_model(self):
        """Load the quantized model using ModelQuantizer"""
        model = ModelQuantizer.quantize_model(
            self.model_class,
            self.model_name,
            self.quantization_type
        )
        # 4-bit and 8-bit models are not moved explicitly here; every other
        # variant is placed on the selected device.
        if self.quantization_type not in ["4-bit", "8-bit"]:
            model = model.to(self.device)
        return model

    @staticmethod
    def _convert_to_serializable(obj):
        """Recursively convert metrics into JSON-serializable types"""
        # np.generic covers every NumPy scalar (float32, int64, etc.).
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.cpu().numpy().tolist()
        if isinstance(obj, dict):
            return {k: ModelHandler._convert_to_serializable(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [ModelHandler._convert_to_serializable(v) for v in obj]
        return obj
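
    # Illustrative round-trip (placeholder values, not executed at import):
    #   _convert_to_serializable({"logits": torch.tensor([1.0, 2.0]), "n": np.int64(3)})
    #   -> {"logits": [1.0, 2.0], "n": 3}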

    def _format_metric_value(self, value):
        """Format a metric value for display based on its type"""
        if isinstance(value, (float, np.float32, np.float64)):
            return f"{value:.8f}"
        elif isinstance(value, (int, np.int32, np.int64)):
            return str(value)
        elif isinstance(value, list):
            return "\n" + "\n".join([f" - {item}" for item in value])
        elif isinstance(value, dict):
            return "\n" + "\n".join([f" {k}: {v}" for k, v in value.items()])
        else:
            return str(value)

    def run_inference(self, model, text):
        """Run inference, returning (outputs, elapsed_seconds) - to be implemented by subclasses"""
        raise NotImplementedError

    def decode_output(self, outputs):
        """Decode model outputs - to be implemented by subclasses"""
        raise NotImplementedError

    def compare_outputs(self, original_outputs, quantized_outputs):
        """Return a dict of comparison metrics - to be implemented by subclasses (called from compare())"""
        raise NotImplementedError

    def compare(self):
        """Compare original and quantized models"""
        try:
            if self.test_text is None:
                logger.warning("No test text provided. Skipping inference testing.")
                return self.quantized_model

            # Run inference on both variants with the same input.
            original_outputs, original_time = self.run_inference(self.original_model, self.test_text)
            quantized_outputs, quantized_time = self.run_inference(self.quantized_model, self.test_text)

            original_size = ModelQuantizer.get_model_size(self.original_model)
            quantized_size = ModelQuantizer.get_model_size(self.quantized_model)

            logger.info(f"Original Model Size: {original_size:.2f} MB")
            logger.info(f"Quantized Model Size: {quantized_size:.2f} MB")
            logger.info(f"Original Inference Time: {original_time:.4f} seconds")
            logger.info(f"Quantized Inference Time: {quantized_time:.4f} seconds")

            # Compare outputs
            comparison_metrics = self.compare_outputs(original_outputs, quantized_outputs) or {}
            for key, value in comparison_metrics.items():
                comparison_metrics[key] = self._convert_to_serializable(value)

            self.metrics = ModelMetrics(
                model_sizes={
                    "original": float(original_size),
                    "quantized": float(quantized_size),
                },
                inference_times={
                    "original": float(original_time),
                    "quantized": float(quantized_time),
                },
                comparison_metrics=comparison_metrics,
            )
            return self.quantized_model
        except Exception as e:
            logger.error(f"Quantization and comparison failed: {str(e)}")
            raise

    def get_metrics(self) -> Dict[str, Any]:
        """Return the metrics as a JSON-serializable dictionary"""
        empty_metrics = {
            "model_sizes": {"original": 0.0, "quantized": 0.0},
            "inference_times": {"original": 0.0, "quantized": 0.0},
            "comparison_metrics": {}
        }
        if self.metrics is None:
            return empty_metrics
        serializable_metrics = self._convert_to_serializable(asdict(self.metrics))
        try:
            json.dumps(serializable_metrics)
            return serializable_metrics
        except (TypeError, ValueError) as e:
            logger.error(f"Error serializing metrics: {str(e)}")
            return empty_metrics
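

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal subclass wiring up the three
# abstract hooks for a sequence-classification model. It assumes a Hugging
# Face `transformers` model/tokenizer pair; the subclass name, checkpoint,
# test text, and metric names below are placeholders, not part of this
# module's API.
if __name__ == "__main__":
    import time

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    class SequenceClassificationHandler(ModelHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        def run_inference(self, model, text):
            # Tokenize onto the model's device and time a single forward pass.
            inputs = self.tokenizer(text, return_tensors="pt").to(model.device)
            start = time.perf_counter()
            with torch.no_grad():
                outputs = model(**inputs)
            return outputs.logits, time.perf_counter() - start

        def decode_output(self, outputs):
            # Map logits to the predicted class index.
            return int(outputs.argmax(dim=-1).item())

        def compare_outputs(self, original_outputs, quantized_outputs):
            # Simple elementwise comparison of the two models' logits.
            diff = (original_outputs.float() - quantized_outputs.float()).abs()
            return {"max_abs_diff": diff.max(), "mean_abs_diff": diff.mean()}

    handler = SequenceClassificationHandler(
        "distilbert-base-uncased-finetuned-sst-2-english",  # placeholder checkpoint
        AutoModelForSequenceClassification,
        "8-bit",
        test_text="This library is great!",
    )
    handler.compare()
    print(json.dumps(handler.get_metrics(), indent=2))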