broadfield-dev committed on
Commit
6e5122c
·
verified ·
1 Parent(s): f813450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -21
app.py CHANGED
@@ -5,13 +5,12 @@ import logging
5
  import time
6
  import tempfile
7
  import shutil
 
8
  from datetime import datetime
9
  from huggingface_hub import HfApi
10
  from transformers import AutoConfig, AutoModel, AutoTokenizer
11
- from optimum.onnxruntime import ORTQuantizer, ORTModelForCausalLM
12
  from optimum.onnxruntime.configuration import AutoQuantizationConfig
13
- # Use the unified optimum.main_export entrypoint
14
- from optimum.exporters.main import main_export
15
  import torch.nn.utils.prune as prune
16
 
17
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -62,14 +61,28 @@ def stage_2_prune_model(model, prune_percentage: float):
62
 
63
  def stage_3_4_onnx_quantize(model_path: str, calibration_data_path: str):
64
  log_stream = "[STAGE 3 & 4] Converting to ONNX and Quantizing...\n"
 
 
 
 
65
  try:
66
- run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
67
- model_name = os.path.basename(model_path)
68
- onnx_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-onnx")
69
-
70
- main_export(model_path, output=onnx_path, task="auto", trust_remote_code=True)
 
 
 
 
 
71
  log_stream += f"Successfully exported base model to ONNX at: {onnx_path}\n"
72
-
 
 
 
 
 
73
  quantizer = ORTQuantizer.from_pretrained(onnx_path)
74
 
75
  if calibration_data_path:
@@ -93,28 +106,37 @@ def stage_3_4_onnx_quantize(model_path: str, calibration_data_path: str):
93
  log_stream += f"Successfully quantized model to: {quantized_path}\n"
94
  return quantized_path, log_stream
95
  except Exception as e:
96
- error_msg = f"Failed during ONNX conversion/quantization. Error: {e}"
97
  logging.error(error_msg, exc_info=True)
98
  raise RuntimeError(error_msg)
99
 
100
  def stage_3_4_gguf_quantize(model_id: str, quantization_strategy: str):
101
  log_stream = f"[STAGE 3 & 4] Converting to GGUF with '{quantization_strategy}' quantization...\n"
102
- try:
103
- run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
104
- model_name = model_id.replace('/', '_')
105
- gguf_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-gguf")
106
- os.makedirs(gguf_path, exist_ok=True)
107
-
108
- main_export(model_id, output=os.path.join(gguf_path, "model.gguf"), export_format="gguf", quantization_strategy=quantization_strategy, trust_remote_code=True)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  log_stream += f"Successfully exported and quantized model to GGUF at: {gguf_path}\n"
111
  return gguf_path, log_stream
112
- except Exception as e:
113
- error_msg = f"Failed during GGUF conversion. Error: {e}"
114
- logging.error(error_msg, exc_info=True)
115
  raise RuntimeError(error_msg)
116
 
117
-
118
  def stage_5_package_and_upload(model_id: str, optimized_model_path: str, pipeline_log: str, options: dict):
119
  log_stream = "[STAGE 5] Packaging and Uploading...\n"
120
  if not HF_TOKEN:
 
5
  import time
6
  import tempfile
7
  import shutil
8
+ import subprocess
9
  from datetime import datetime
10
  from huggingface_hub import HfApi
11
  from transformers import AutoConfig, AutoModel, AutoTokenizer
12
+ from optimum.onnxruntime import ORTQuantizer
13
  from optimum.onnxruntime.configuration import AutoQuantizationConfig
 
 
14
  import torch.nn.utils.prune as prune
15
 
16
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
61
 
62
  def stage_3_4_onnx_quantize(model_path: str, calibration_data_path: str):
63
  log_stream = "[STAGE 3 & 4] Converting to ONNX and Quantizing...\n"
64
+ run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
65
+ model_name = os.path.basename(model_path)
66
+ onnx_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-onnx")
67
+
68
  try:
69
+ log_stream += "Executing `optimum-cli export onnx` via subprocess...\n"
70
+ export_command = [
71
+ "optimum-cli", "export", "onnx",
72
+ "--model", model_path,
73
+ "--trust-remote-code",
74
+ onnx_path
75
+ ]
76
+ process = subprocess.run(export_command, check=True, capture_output=True, text=True)
77
+ log_stream += process.stdout
78
+ if process.stderr: log_stream += f"[STDERR]\n{process.stderr}\n"
79
  log_stream += f"Successfully exported base model to ONNX at: {onnx_path}\n"
80
+ except subprocess.CalledProcessError as e:
81
+ error_msg = f"Failed during `optimum-cli export onnx`. Error:\n{e.stderr}"
82
+ logging.error(error_msg)
83
+ raise RuntimeError(error_msg)
84
+
85
+ try:
86
  quantizer = ORTQuantizer.from_pretrained(onnx_path)
87
 
88
  if calibration_data_path:
 
106
  log_stream += f"Successfully quantized model to: {quantized_path}\n"
107
  return quantized_path, log_stream
108
  except Exception as e:
109
+ error_msg = f"Failed during ONNX quantization step. Error: {e}"
110
  logging.error(error_msg, exc_info=True)
111
  raise RuntimeError(error_msg)
112
 
113
def stage_3_4_gguf_quantize(model_id: str, quantization_strategy: str):
    """Export *model_id* to GGUF via the `optimum-cli` tool and quantize it.

    Args:
        model_id: Hub repository id (e.g. "org/model"); any '/' is flattened
            to '_' when building the output directory name.
        quantization_strategy: passed straight through to the CLI's
            --quantization_strategy flag.

    Returns:
        Tuple of (gguf_path, log_stream): the run-specific output directory
        and the accumulated human-readable log text.

    Raises:
        RuntimeError: if the CLI binary is missing or exits non-zero.
    """
    log_stream = f"[STAGE 3 & 4] Converting to GGUF with '{quantization_strategy}' quantization...\n"
    # Timestamped per-run output directory so repeated runs never collide.
    run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
    model_name = model_id.replace('/', '_')
    gguf_path = os.path.join(OUTPUT_DIR, f"{model_name}-{run_id}-gguf")
    os.makedirs(gguf_path, exist_ok=True)
    output_file = os.path.join(gguf_path, "model.gguf")

    try:
        log_stream += "Executing `optimum-cli export gguf` via subprocess...\n"
        export_command = [
            "optimum-cli", "export", "gguf",
            "--model", model_id,
            "--quantization_strategy", quantization_strategy,
            "--trust-remote-code",
            output_file
        ]
        # check=True converts a non-zero exit into CalledProcessError;
        # stdout/stderr are captured so they can be surfaced in the log.
        process = subprocess.run(export_command, check=True, capture_output=True, text=True)
        log_stream += process.stdout
        if process.stderr:
            log_stream += f"[STDERR]\n{process.stderr}\n"
        log_stream += f"Successfully exported and quantized model to GGUF at: {gguf_path}\n"
        return gguf_path, log_stream
    except FileNotFoundError as e:
        # BUGFIX: previously a missing `optimum-cli` binary escaped as a raw
        # FileNotFoundError instead of the RuntimeError callers expect.
        error_msg = "Failed during `optimum-cli export gguf`: `optimum-cli` executable not found on PATH."
        logging.error(error_msg)
        raise RuntimeError(error_msg) from e
    except subprocess.CalledProcessError as e:
        error_msg = f"Failed during `optimum-cli export gguf`. Error:\n{e.stderr}"
        logging.error(error_msg)
        # Chain the cause so the original traceback is preserved for debugging.
        raise RuntimeError(error_msg) from e
139
 
 
140
  def stage_5_package_and_upload(model_id: str, optimized_model_path: str, pipeline_log: str, options: dict):
141
  log_stream = "[STAGE 5] Packaging and Uploading...\n"
142
  if not HF_TOKEN: