badaoui HF Staff committed on
Commit
dc818fc
·
verified ·
1 Parent(s): a081d41

Update optimum_neuron_export.py

Browse files
Files changed (1) hide show
  1. optimum_neuron_export.py +29 -1
optimum_neuron_export.py CHANGED
@@ -116,6 +116,33 @@ def get_default_inputs(task_or_pipeline: str) -> Dict[str, int]:
116
  # Default to text-based shapes
117
  return {"batch_size": 1, "sequence_length": 128}
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
120
  try:
121
  discussions = api.get_repo_discussions(repo_id=model_id)
@@ -140,6 +167,7 @@ def export_and_git_add(model_id: str, task_or_pipeline: str, model_type: str, fo
140
  raise Exception(f"❌ Unsupported task/pipeline: {task_or_pipeline}. Supported: {supported}")
141
 
142
  inputs = get_default_inputs(task_or_pipeline)
 
143
  yield f"🔧 Using default inputs: {inputs}"
144
 
145
  try:
@@ -149,7 +177,7 @@ def export_and_git_add(model_id: str, task_or_pipeline: str, model_type: str, fo
149
  tensor_parallel_size=1,
150
  token=HF_TOKEN,
151
  cpu_backend=True,
152
- compiler_args="--target inf2"
153
  **inputs,
154
  )
155
  model.save_pretrained(folder)
 
116
  # Default to text-based shapes
117
  return {"batch_size": 1, "sequence_length": 128}
118
 
119
def prepare_compiler_flags(
    auto_cast: str | None = None,
    auto_cast_type: str = "bf16",
    optlevel: str = "2",
    instance_type: str = "trn1",
):
    """Build the space-separated compiler argument string for an export.

    Args:
        auto_cast: Auto-cast mode to request. ``None`` emits
            ``--auto-cast none`` and skips the cast type entirely. The
            spelling ``"matmul"`` is rewritten to ``"matmult"`` before it
            is placed on the command line.
        auto_cast_type: Value passed to ``--auto-cast-type``; only used
            when ``auto_cast`` is not ``None``.
        optlevel: Value passed to ``--optlevel``.
        instance_type: ``"trn1"`` or ``"trn2"`` adds a matching
            ``--target`` flag; any other value adds no ``--target`` flag.

    Returns:
        All selected flags joined by single spaces.
    """
    if auto_cast is None:
        flags = ["--auto-cast", "none"]
    else:
        # The log line reports the value as the caller supplied it, i.e.
        # before the "matmul" -> "matmult" rewrite below.
        logger.info(f"Using Neuron: --auto-cast {auto_cast}")
        cast_mode = "matmult" if auto_cast == "matmul" else auto_cast
        logger.info(f"Using Neuron: --auto-cast-type {auto_cast_type}")
        flags = ["--auto-cast", cast_mode, "--auto-cast-type", auto_cast_type]

    flags += ["--optlevel", optlevel]
    logger.info(f"Using Neuron: --optlevel {optlevel}")

    # Only these two instance types get an explicit compilation target.
    if instance_type in ("trn1", "trn2"):
        flags += ["--target", instance_type]

    return " ".join(flags)
145
+
146
  def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
147
  try:
148
  discussions = api.get_repo_discussions(repo_id=model_id)
 
167
  raise Exception(f"❌ Unsupported task/pipeline: {task_or_pipeline}. Supported: {supported}")
168
 
169
  inputs = get_default_inputs(task_or_pipeline)
170
+ compiler_args = prepare_compiler_flags(auto_cast, auto_cast_type, optlevel, instance_type)
171
  yield f"🔧 Using default inputs: {inputs}"
172
 
173
  try:
 
177
  tensor_parallel_size=1,
178
  token=HF_TOKEN,
179
  cpu_backend=True,
180
+ compiler_args=compiler_args,
181
  **inputs,
182
  )
183
  model.save_pretrained(folder)