# convert/app.py
# Hugging Face Space app (uploaded by atharva98, commit 68671d2).
import os
from typing import Optional
import gradio as gr
from huggingface_hub import HfApi
from convert import convert
# Dataset repo that records conversion results — appears unused in this file;
# presumably consumed by convert() or kept for reference. TODO confirm.
DATASET_REPO_URL = "https://huggingface.co/datasets/safetensors/conversions"
HF_TOKEN = os.environ.get("HF_TOKEN")  # Bot token (recommended for opening PRs)
def _format_error(msg: str) -> str:
return f"""
### Error 😢
{msg}
"""
def _format_invalid(msg: str) -> str:
return f"""
### Invalid input 🐞
{msg}
"""
def run(model_id: str, is_private: bool, token: Optional[str] = None) -> str:
    """Convert *model_id* to safetensors and open a PR on the Hub.

    Parameters
    ----------
    model_id:
        Repo id on the Hub, e.g. ``org/model_name``. Whitespace is stripped.
    is_private:
        Whether the model repo is private; if True, *token* is required.
    token:
        User-supplied read-access HF token (only needed for private repos).

    Returns
    -------
    A Markdown string: a success panel (with the PR/commit URL when one is
    returned by ``convert()``), an invalid-input panel, or an error panel.
    """
    model_id = (model_id or "").strip()
    token = (token or "").strip() if token else None

    if not model_id:
        return _format_invalid("Please provide a valid `model_id`.")

    # Token policy:
    # - Public model: use the bot token when available; unauthenticated
    #   access may read the repo but PR creation can fail.
    # - Private model: the user's read-access token is required.
    if is_private and not token:
        return _format_invalid("Private model checked: please provide a read-access HF token.")

    # Reading model info: user's token for private repos, bot token (or
    # unauthenticated) otherwise. `or None` maps an empty env var to None.
    api = HfApi(token=token if is_private else (HF_TOKEN or None))
    try:
        info = api.model_info(repo_id=model_id)
        hf_is_private = bool(getattr(info, "private", False))

        # If the user checked "private" but the repo is actually public,
        # switch to the bot token so the PR is opened on behalf of the bot.
        if is_private and not hf_is_private:
            api = HfApi(token=HF_TOKEN or None)

        commit_info, errors = convert(api=api, model_id=model_id)

        # Implementations of convert() differ: prefer the PR URL and fall
        # back to the plain commit URL.
        pr_url = getattr(commit_info, "pr_url", None) or getattr(commit_info, "url", None)

        out = (
            "\n### Success 🔥\n"
            "Yay! This model was successfully converted.\n"
        )
        if pr_url:
            out += f"\nPR/Discussion URL:\n\n[{pr_url}]({pr_url})\n"
        else:
            out += "\n(Conversion succeeded, but no PR URL was returned by `convert()`.)\n"

        if errors:
            out += "\n\n### Conversion warnings\n"
            # BUG FIX: report the actual filename for each skipped file —
            # the original hardcoded the literal "(unknown)" placeholder
            # even though the filename was unpacked from each error tuple.
            out += "\n".join(
                f"- `{filename}`: {e} (skipped)"
                for filename, e in errors
            )

        # Helpful reminder when the bot token is missing.
        if (not is_private) and (HF_TOKEN is None):
            out += "\n\n⚠️ Note: `HF_TOKEN` is not set. If PR creation fails, set `HF_TOKEN` in the environment."
        return out
    except Exception as e:
        # Top-level UI boundary: surface any failure (auth, network,
        # conversion) as a Markdown error panel rather than crashing.
        return _format_error(str(e))
# Markdown shown at the top of the UI explaining the workflow.
DESCRIPTION = """
The steps are the following:
- (Optional) Set `HF_TOKEN` as an environment variable for your bot/account token to open PRs reliably.
- If the model is private, paste a read-access token from hf.co/settings/tokens.
- Input a model id from the Hub (e.g. `org/model_name`)
- Click "Submit"
- You'll get the URL of the opened PR (or discussion) if successful.
⚠️ For now only `pytorch_model.bin` files are supported but we’ll extend in the future.
"""
# Page title, reused as the top-level Markdown heading below.
title = "Convert any model to Safetensors and open a PR"
def toggle_token_visibility(is_private: bool):
    """Show the token textbox only while the private-model box is checked.

    Returning a component instance is the Gradio-recommended way to update
    an existing component's configuration (here, its visibility).
    """
    should_show = bool(is_private)
    return gr.Textbox(visible=should_show)
# Build the Gradio UI: a two-column layout with inputs on the left and the
# Markdown result of run() on the right.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"# {title}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # Input widgets.
            model_id = gr.Textbox(label="model_id", placeholder="e.g. bigscience/bloom", max_lines=1)
            is_private = gr.Checkbox(label="Private model", value=False)
            # Hidden until the "Private model" checkbox is ticked; toggled by
            # the change handler wired up at the bottom of this block.
            token = gr.Textbox(
                label="your_hf_token",
                placeholder="hf_...",
                max_lines=1,
                type="password",
                visible=False,
            )
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                # ClearButton should be instantiated with components to clear
                clean = gr.ClearButton([model_id, is_private, token])
        with gr.Column():
            # Result pane populated by run().
            output = gr.Markdown()
    # Show/hide the token textbox whenever the checkbox changes.
    is_private.change(fn=toggle_token_visibility, inputs=is_private, outputs=token)
    # concurrency_limit=1: conversions are heavy, run one job at a time.
    submit.click(fn=run, inputs=[model_id, is_private, token], outputs=output, concurrency_limit=1)
# Queue at most 10 waiting jobs; show_api exposes the programmatic endpoint.
demo.queue(max_size=10).launch(show_api=True)