Felladrin committed on
Commit
548f210
·
1 Parent(s): 320c9ad

Allow users to decide if they want to base the 'task' argument from the
conversion script on the model's pipeline tag

Browse files

Because some models (e.g. "Norm/nougat-latex-base") can only be
converted with this option enabled.

Files changed (1) hide show
  1. app.py +22 -0
app.py CHANGED
@@ -309,6 +309,7 @@ class ModelConverter:
309
  input_model_id: str,
310
  trust_remote_code: bool = False,
311
  output_attentions: bool = False,
 
312
  ) -> Tuple[bool, Optional[str]]:
313
  """Convert a Hugging Face model to ONNX format.
314
 
@@ -316,6 +317,7 @@ class ModelConverter:
316
  input_model_id: Hugging Face model repository ID
317
  trust_remote_code: Whether to trust and execute remote code from the model
318
  output_attentions: Whether to output attention weights (required for some tasks)
 
319
 
320
  Returns:
321
  Tuple containing:
@@ -337,6 +339,18 @@ class ModelConverter:
337
  if output_attentions:
338
  conversion_args.append("--output_attentions")
339
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  # Run the conversion
341
  result = self._run_conversion_subprocess(
342
  input_model_id, extra_args=conversion_args or None
@@ -521,6 +535,13 @@ def main():
521
  "Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
522
  )
523
 
 
 
 
 
 
 
 
524
  # Determine output repository
525
  # If user owns the model, allow uploading to the same repo
526
  if config.hf_username == input_model_id.split("/")[0]:
@@ -559,6 +580,7 @@ def main():
559
  input_model_id,
560
  trust_remote_code=trust_remote_code,
561
  output_attentions=output_attentions,
 
562
  )
563
  if not success:
564
  st.error(f"Conversion failed: {stderr}")
 
309
  input_model_id: str,
310
  trust_remote_code: bool = False,
311
  output_attentions: bool = False,
312
+ enable_task_inference: bool = True,
313
  ) -> Tuple[bool, Optional[str]]:
314
  """Convert a Hugging Face model to ONNX format.
315
 
 
317
  input_model_id: Hugging Face model repository ID
318
  trust_remote_code: Whether to trust and execute remote code from the model
319
  output_attentions: Whether to output attention weights (required for some tasks)
320
+ enable_task_inference: Whether to pass the task argument to the conversion script based on the model's pipeline tag
321
 
322
  Returns:
323
  Tuple containing:
 
339
  if output_attentions:
340
  conversion_args.append("--output_attentions")
341
 
342
+ if enable_task_inference:
343
+ try:
344
+ info = model_info(
345
+ repo_id=input_model_id, token=self.config.hf_token
346
+ )
347
+ pipeline_tag = getattr(info, "pipeline_tag", None)
348
+ task = self._normalize_pipeline_tag(pipeline_tag)
349
+ if task:
350
+ conversion_args.extend(["--task", task])
351
+ except Exception:
352
+ pass
353
+
354
  # Run the conversion
355
  result = self._run_conversion_subprocess(
356
  input_model_id, extra_args=conversion_args or None
 
535
  "Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
536
  )
537
 
538
+ # Optional: Task inference toggle
539
+ enable_task_inference = st.toggle(
540
+ "Optional: Base the 'task' argument from the conversion script on the model's pipeline tag",
541
+ value=False,
542
+ help="This can make the conversion of some models work, but may cause issues for others. It's recommended to first try converting the model with this option disabled, and only enable it if the conversion fails.",
543
+ )
544
+
545
  # Determine output repository
546
  # If user owns the model, allow uploading to the same repo
547
  if config.hf_username == input_model_id.split("/")[0]:
 
580
  input_model_id,
581
  trust_remote_code=trust_remote_code,
582
  output_attentions=output_attentions,
583
+ enable_task_inference=enable_task_inference,
584
  )
585
  if not success:
586
  st.error(f"Conversion failed: {stderr}")