Allow users to decide if they want to base the 'task' argument from the
Browse filesconversion script on the model's pipeline tag
Because some models (e.g. "Norm/nougat-latex-base") can only be
converted with this option enabled.
app.py
CHANGED
|
@@ -309,6 +309,7 @@ class ModelConverter:
|
|
| 309 |
input_model_id: str,
|
| 310 |
trust_remote_code: bool = False,
|
| 311 |
output_attentions: bool = False,
|
|
|
|
| 312 |
) -> Tuple[bool, Optional[str]]:
|
| 313 |
"""Convert a Hugging Face model to ONNX format.
|
| 314 |
|
|
@@ -316,6 +317,7 @@ class ModelConverter:
|
|
| 316 |
input_model_id: Hugging Face model repository ID
|
| 317 |
trust_remote_code: Whether to trust and execute remote code from the model
|
| 318 |
output_attentions: Whether to output attention weights (required for some tasks)
|
|
|
|
| 319 |
|
| 320 |
Returns:
|
| 321 |
Tuple containing:
|
|
@@ -337,6 +339,18 @@ class ModelConverter:
|
|
| 337 |
if output_attentions:
|
| 338 |
conversion_args.append("--output_attentions")
|
| 339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
# Run the conversion
|
| 341 |
result = self._run_conversion_subprocess(
|
| 342 |
input_model_id, extra_args=conversion_args or None
|
|
@@ -521,6 +535,13 @@ def main():
|
|
| 521 |
"Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
|
| 522 |
)
|
| 523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
# Determine output repository
|
| 525 |
# If user owns the model, allow uploading to the same repo
|
| 526 |
if config.hf_username == input_model_id.split("/")[0]:
|
|
@@ -559,6 +580,7 @@ def main():
|
|
| 559 |
input_model_id,
|
| 560 |
trust_remote_code=trust_remote_code,
|
| 561 |
output_attentions=output_attentions,
|
|
|
|
| 562 |
)
|
| 563 |
if not success:
|
| 564 |
st.error(f"Conversion failed: {stderr}")
|
|
|
|
| 309 |
input_model_id: str,
|
| 310 |
trust_remote_code: bool = False,
|
| 311 |
output_attentions: bool = False,
|
| 312 |
+
enable_task_inference: bool = True,
|
| 313 |
) -> Tuple[bool, Optional[str]]:
|
| 314 |
"""Convert a Hugging Face model to ONNX format.
|
| 315 |
|
|
|
|
| 317 |
input_model_id: Hugging Face model repository ID
|
| 318 |
trust_remote_code: Whether to trust and execute remote code from the model
|
| 319 |
output_attentions: Whether to output attention weights (required for some tasks)
|
| 320 |
+
enable_task_inference: Whether to pass the task argument to the conversion script based on the model's pipeline tag
|
| 321 |
|
| 322 |
Returns:
|
| 323 |
Tuple containing:
|
|
|
|
| 339 |
if output_attentions:
|
| 340 |
conversion_args.append("--output_attentions")
|
| 341 |
|
| 342 |
+
if enable_task_inference:
|
| 343 |
+
try:
|
| 344 |
+
info = model_info(
|
| 345 |
+
repo_id=input_model_id, token=self.config.hf_token
|
| 346 |
+
)
|
| 347 |
+
pipeline_tag = getattr(info, "pipeline_tag", None)
|
| 348 |
+
task = self._normalize_pipeline_tag(pipeline_tag)
|
| 349 |
+
if task:
|
| 350 |
+
conversion_args.extend(["--task", task])
|
| 351 |
+
except Exception:
|
| 352 |
+
pass
|
| 353 |
+
|
| 354 |
# Run the conversion
|
| 355 |
result = self._run_conversion_subprocess(
|
| 356 |
input_model_id, extra_args=conversion_args or None
|
|
|
|
| 535 |
"Whether to output attentions from the Whisper model. This is required for word-level (token) timestamps."
|
| 536 |
)
|
| 537 |
|
| 538 |
+
# Optional: Task inference toggle
|
| 539 |
+
enable_task_inference = st.toggle(
|
| 540 |
+
"Optional: Base the 'task' argument from the conversion script on the model's pipeline tag",
|
| 541 |
+
value=False,
|
| 542 |
+
help="This can make the conversion of some models work, but may cause issues for others. It's recommended to first try converting the model with this option disabled, and only enable it if the conversion fails.",
|
| 543 |
+
)
|
| 544 |
+
|
| 545 |
# Determine output repository
|
| 546 |
# If user owns the model, allow uploading to the same repo
|
| 547 |
if config.hf_username == input_model_id.split("/")[0]:
|
|
|
|
| 580 |
input_model_id,
|
| 581 |
trust_remote_code=trust_remote_code,
|
| 582 |
output_attentions=output_attentions,
|
| 583 |
+
enable_task_inference=enable_task_inference,
|
| 584 |
)
|
| 585 |
if not success:
|
| 586 |
st.error(f"Conversion failed: {stderr}")
|