Spaces:
Running
Running
Update optimum_neuron_export.py
Browse files- optimum_neuron_export.py +5 -3
optimum_neuron_export.py
CHANGED
|
@@ -162,17 +162,17 @@ def export(model_id: str, task_or_pipeline: str, model_type: str, folder: str):
|
|
| 162 |
**inputs
|
| 163 |
)
|
| 164 |
neuron_model = NeuronModelForCausalLM.export(
|
| 165 |
-
model_id=model_id,
|
| 166 |
neuron_config=neuron_config,
|
| 167 |
token=HF_TOKEN,
|
| 168 |
)
|
| 169 |
-
neuron_model.save_pretrained(folder)
|
| 170 |
|
| 171 |
# DIFFUSION tasks
|
| 172 |
elif task_or_pipeline in DIFFUSION_PIPELINE_MAPPING:
|
| 173 |
model_class = DIFFUSION_PIPELINE_MAPPING.get(task_or_pipeline)
|
| 174 |
model = model_class.from_pretrained(model_id)
|
| 175 |
-
|
| 176 |
compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
|
| 177 |
|
| 178 |
result = main_export(
|
|
@@ -182,10 +182,12 @@ def export(model_id: str, task_or_pipeline: str, model_type: str, folder: str):
|
|
| 182 |
torch_dtype= torch.bfloat16,
|
| 183 |
token=HF_TOKEN,
|
| 184 |
library_name=model_type,
|
|
|
|
| 185 |
cpu_backend=True,
|
| 186 |
model=model,
|
| 187 |
**inputs,
|
| 188 |
)
|
|
|
|
| 189 |
else:
|
| 190 |
raise ValueError(f"Unsupported task or pipeline: {task_or_pipeline}")
|
| 191 |
|
|
|
|
| 162 |
**inputs
|
| 163 |
)
|
| 164 |
neuron_model = NeuronModelForCausalLM.export(
|
| 165 |
+
model_id=model_id,
|
| 166 |
neuron_config=neuron_config,
|
| 167 |
token=HF_TOKEN,
|
| 168 |
)
|
| 169 |
+
neuron_model.save_pretrained(folder)
|
| 170 |
|
| 171 |
# DIFFUSION tasks
|
| 172 |
elif task_or_pipeline in DIFFUSION_PIPELINE_MAPPING:
|
| 173 |
model_class = DIFFUSION_PIPELINE_MAPPING.get(task_or_pipeline)
|
| 174 |
model = model_class.from_pretrained(model_id)
|
| 175 |
+
input_shapes = build_stable_diffusion_components_mandatory_shapes(**inputs)
|
| 176 |
compiler_kwargs = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
|
| 177 |
|
| 178 |
result = main_export(
|
|
|
|
| 182 |
torch_dtype= torch.bfloat16,
|
| 183 |
token=HF_TOKEN,
|
| 184 |
library_name=model_type,
|
| 185 |
+
tensor_parallel_size=4,
|
| 186 |
cpu_backend=True,
|
| 187 |
model=model,
|
| 188 |
**inputs,
|
| 189 |
)
|
| 190 |
+
|
| 191 |
else:
|
| 192 |
raise ValueError(f"Unsupported task or pipeline: {task_or_pipeline}")
|
| 193 |
|