badaoui HF Staff committed on
Commit
e675f86
·
verified ·
1 Parent(s): 7eca9b3

Update optimum_neuron_export.py

Browse files
Files changed (1) hide show
  1. optimum_neuron_export.py +5 -3
optimum_neuron_export.py CHANGED
@@ -162,17 +162,17 @@ def export(model_id: str, task_or_pipeline: str, model_type: str, folder: str):
162
  **inputs
163
  )
164
  neuron_model = NeuronModelForCausalLM.export(
165
- model_id=model_id, # Fixed variable name
166
  neuron_config=neuron_config,
167
  token=HF_TOKEN,
168
  )
169
- neuron_model.save_pretrained(folder) # Fixed variable name
170
 
171
  # DIFFUSION tasks
172
  elif task_or_pipeline in DIFFUSION_PIPELINE_MAPPING:
173
  model_class = DIFFUSION_PIPELINE_MAPPING.get(task_or_pipeline)
174
  model = model_class.from_pretrained(model_id)
175
- # input_shapes = build_stable_diffusion_components_mandatory_shapes(**inputs)
176
  compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
177
 
178
  result = main_export(
@@ -182,10 +182,12 @@ def export(model_id: str, task_or_pipeline: str, model_type: str, folder: str):
182
  torch_dtype= torch.bfloat16,
183
  token=HF_TOKEN,
184
  library_name=model_type,
 
185
  cpu_backend=True,
186
  model=model,
187
  **inputs,
188
  )
 
189
  else:
190
  raise ValueError(f"Unsupported task or pipeline: {task_or_pipeline}")
191
 
 
162
  **inputs
163
  )
164
  neuron_model = NeuronModelForCausalLM.export(
165
+ model_id=model_id,
166
  neuron_config=neuron_config,
167
  token=HF_TOKEN,
168
  )
169
+ neuron_model.save_pretrained(folder)
170
 
171
  # DIFFUSION tasks
172
  elif task_or_pipeline in DIFFUSION_PIPELINE_MAPPING:
173
  model_class = DIFFUSION_PIPELINE_MAPPING.get(task_or_pipeline)
174
  model = model_class.from_pretrained(model_id)
175
+ input_shapes = build_stable_diffusion_components_mandatory_shapes(**inputs)
176
  compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
177
 
178
  result = main_export(
 
182
  torch_dtype= torch.bfloat16,
183
  token=HF_TOKEN,
184
  library_name=model_type,
185
+ tensor_parallel_size=4,
186
  cpu_backend=True,
187
  model=model,
188
  **inputs,
189
  )
190
+
191
  else:
192
  raise ValueError(f"Unsupported task or pipeline: {task_or_pipeline}")
193