Select pipeline_tag task by default.
Browse files
app.py
CHANGED
|
@@ -63,12 +63,26 @@ tasks_mapping = {
|
|
| 63 |
"semantic-segmentation": "Semantic Segmentation",
|
| 64 |
"seq2seq-lm": "Text to Text Generation",
|
| 65 |
"sequence-classification": "Text Classification",
|
| 66 |
-
"speech-seq2seq": "
|
| 67 |
"token-classification": "Token Classification",
|
| 68 |
}
|
| 69 |
reverse_tasks_mapping = {v: k for k, v in tasks_mapping.items()}
|
| 70 |
tasks_labels = list(tasks_mapping.keys())
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def error_str(error, title="Error", model=None, task=None, framework=None, compute_units=None, precision=None, tolerance=None, destination=None, open_discussion=True):
|
| 73 |
if not error: return ""
|
| 74 |
|
|
@@ -112,8 +126,18 @@ def get_pr_url(api, repo_id, title):
|
|
| 112 |
and discussion.title == title
|
| 113 |
):
|
| 114 |
return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
|
| 115 |
-
|
| 116 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
"""
|
| 118 |
Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
|
| 119 |
Only PyTorch and Tensorflow are supported.
|
|
@@ -130,6 +154,7 @@ def on_model_change(model):
|
|
| 130 |
error = None
|
| 131 |
frameworks = []
|
| 132 |
selected_framework = None
|
|
|
|
| 133 |
|
| 134 |
try:
|
| 135 |
config_file = hf_hub_download(model, filename="config.json")
|
|
@@ -144,17 +169,30 @@ def on_model_change(model):
|
|
| 144 |
|
| 145 |
features = FeaturesManager.get_supported_features_for_model_type(model_type)
|
| 146 |
tasks = list(features.keys())
|
| 147 |
-
tasks = [tasks_mapping[task] for task in tasks]
|
| 148 |
|
| 149 |
-
|
|
|
|
| 150 |
selected_framework = frameworks[0] if len(frameworks) > 0 else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
except Exception as e:
|
| 152 |
error = e
|
| 153 |
model_type = None
|
| 154 |
|
| 155 |
return (
|
| 156 |
gr.update(visible=bool(model_type)), # Settings column
|
| 157 |
-
gr.update(choices=tasks, value=
|
| 158 |
gr.update(visible=len(frameworks)>1, choices=frameworks, value=selected_framework), # Frameworks
|
| 159 |
gr.update(value=error_str(error, model=model)), # Error
|
| 160 |
)
|
|
|
|
| 63 |
"semantic-segmentation": "Semantic Segmentation",
|
| 64 |
"seq2seq-lm": "Text to Text Generation",
|
| 65 |
"sequence-classification": "Text Classification",
|
| 66 |
+
"speech-seq2seq": "Audio to Audio",
|
| 67 |
"token-classification": "Token Classification",
|
| 68 |
}
|
| 69 |
reverse_tasks_mapping = {v: k for k, v in tasks_mapping.items()}
|
| 70 |
tasks_labels = list(tasks_mapping.keys())
|
| 71 |
|
| 72 |
+
# Map pipeline_tag to internal exporters features/tasks
|
| 73 |
+
tags_to_tasks_mapping = {
|
| 74 |
+
"feature-extraction": "default",
|
| 75 |
+
"text-generation": "causal-lm",
|
| 76 |
+
"image-classification": "image-classification",
|
| 77 |
+
"image-segmentation": "image-segmentation",
|
| 78 |
+
"fill-mask": "masked-lm",
|
| 79 |
+
"object-detection": "object-detection",
|
| 80 |
+
"question-answering": "question-answering",
|
| 81 |
+
"text2text-generation": "seq2seq-lm",
|
| 82 |
+
"text-classification": "sequence-classification",
|
| 83 |
+
"token-classification": "token-classification",
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
def error_str(error, title="Error", model=None, task=None, framework=None, compute_units=None, precision=None, tolerance=None, destination=None, open_discussion=True):
|
| 87 |
if not error: return ""
|
| 88 |
|
|
|
|
| 126 |
and discussion.title == title
|
| 127 |
):
|
| 128 |
return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
|
| 129 |
+
|
| 130 |
+
def retrieve_model_info(model_id):
|
| 131 |
+
api = HfApi()
|
| 132 |
+
model_info = api.model_info(model_id)
|
| 133 |
+
tags = model_info.tags
|
| 134 |
+
frameworks = [tag for tag in tags if tag in ["pytorch", "tf"]]
|
| 135 |
+
return {
|
| 136 |
+
"pipeline_tag": model_info.pipeline_tag,
|
| 137 |
+
"frameworks": sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks]),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
def supported_frameworks(model_info):
|
| 141 |
"""
|
| 142 |
Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
|
| 143 |
Only PyTorch and Tensorflow are supported.
|
|
|
|
| 154 |
error = None
|
| 155 |
frameworks = []
|
| 156 |
selected_framework = None
|
| 157 |
+
selected_task = None
|
| 158 |
|
| 159 |
try:
|
| 160 |
config_file = hf_hub_download(model, filename="config.json")
|
|
|
|
| 169 |
|
| 170 |
features = FeaturesManager.get_supported_features_for_model_type(model_type)
|
| 171 |
tasks = list(features.keys())
|
|
|
|
| 172 |
|
| 173 |
+
model_info = retrieve_model_info(model)
|
| 174 |
+
frameworks = model_info["frameworks"]
|
| 175 |
selected_framework = frameworks[0] if len(frameworks) > 0 else None
|
| 176 |
+
|
| 177 |
+
pipeline_tag = model_info["pipeline_tag"]
|
| 178 |
+
# Select the task corresponding to the pipeline tag
|
| 179 |
+
if tasks:
|
| 180 |
+
if pipeline_tag in tags_to_tasks_mapping:
|
| 181 |
+
selected_task = tags_to_tasks_mapping[pipeline_tag]
|
| 182 |
+
else:
|
| 183 |
+
selected_task = tasks[0]
|
| 184 |
+
|
| 185 |
+
# Convert to UI labels
|
| 186 |
+
tasks = [tasks_mapping[task] for task in tasks]
|
| 187 |
+
selected_task = tasks_mapping[selected_task]
|
| 188 |
+
|
| 189 |
except Exception as e:
|
| 190 |
error = e
|
| 191 |
model_type = None
|
| 192 |
|
| 193 |
return (
|
| 194 |
gr.update(visible=bool(model_type)), # Settings column
|
| 195 |
+
gr.update(choices=tasks, value=selected_task), # Tasks
|
| 196 |
gr.update(visible=len(frameworks)>1, choices=frameworks, value=selected_framework), # Frameworks
|
| 197 |
gr.update(value=error_str(error, model=model)), # Error
|
| 198 |
)
|