badaoui HF Staff commited on
Commit
3e33548
·
verified ·
1 Parent(s): aa79869

Update optimum_neuron_export.py

Browse files
Files changed (1) hide show
  1. optimum_neuron_export.py +11 -10
optimum_neuron_export.py CHANGED
@@ -276,20 +276,21 @@ def export_transformer_model(model_id: str, task: str, folder: str, token: str)
276
  raise Exception(f"❌ Unsupported task: {task}. Supported: {supported}")
277
 
278
  inputs = get_default_inputs(task)
279
- compiler_configs = {"tensor_parallel_size": 4, "instance_type":"trn1"}
280
- yield f"🔧 Using default inputs: {inputs} with model_class : {model_class}"
281
 
282
  # Clear any old cache artifacts before export
283
- cache_base_dir = "/var/tmp/neuron-compile-cache"
284
-
285
  try:
286
  # Trigger the export/compilation
287
-
288
- export_kwargs = inputs | compiler_configs
289
- neuron_config = model_class.get_neuron_config(model_name_or_path=model_id, **export_kwargs)
290
- model = model_class.export(
291
- model_id=model_id, neuron_config=neuron_config
292
- )
 
 
293
 
294
  yield "✅ Export/compilation completed successfully."
295
 
 
276
  raise Exception(f"❌ Unsupported task: {task}. Supported: {supported}")
277
 
278
  inputs = get_default_inputs(task)
279
+ compiler_configs = {"auto_cast": "matmul", "auto_cast_type": "bf16", "instance_type": "inf2"}
280
+ yield f"🔧 Using default inputs: {inputs}"
281
 
282
  # Clear any old cache artifacts before export
283
+
 
284
  try:
285
  # Trigger the export/compilation
286
+ model = model_class.from_pretrained(
287
+ model_id,
288
+ export=True,
289
+ tensor_parallel_size=4,
290
+ token=token,
291
+ **compiler_configs,
292
+ **inputs,
293
+ )
294
 
295
  yield "✅ Export/compilation completed successfully."
296