Blindfold / ui.py
Flagstone8878's picture
Update ui.py
7d69a83 verified
"""
Gradio UI for the Blindfold Model tool.
This module contains all Gradio-specific code and handles the OAuth integration.
"""
import gradio as gr
import logging
import time
from core import blindfold_model_impl, generate_default_repo_name
logger = logging.getLogger(__name__)
def blindfold_model(
    model_url: str,
    repo_name: str,
    private_repo: bool,
    oauth_token: gr.OAuthToken | None = None,
    progress: gr.Progress = gr.Progress(),
):
    """
    Gradio event handler: run one blindfold job and hand back its result.

    Unwraps the access token from the OAuth object that Gradio injects for
    logged-in users, adapts the Gradio progress tracker into a plain
    ``(pct, msg)`` callback, and delegates the actual work to
    ``blindfold_model_impl``. The start and end of the call (with elapsed
    time) are logged; the token itself is never logged, only its presence.

    Args:
        model_url: Hugging Face model URL or ID (e.g., "Qwen/Qwen3.5-35B-A3B")
        repo_name: Name for the output repository
        private_repo: Whether the output repository should be private
        oauth_token: Gradio OAuth token (injected by Gradio; None when absent)
        progress: Gradio progress tracker

    Returns:
        The value returned by ``blindfold_model_impl``, unchanged. The UI
        binds it to two outputs: (status_message, output_url).
    """
    started_at = time.time()
    timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
    logger.info(f"=== Gradio blindfold_model called at {timestamp} ===")
    logger.info(
        f"Model URL: {model_url}, Repo name: {repo_name}, Private: {private_repo}"
    )
    logger.info(f"OAuth token present: {oauth_token is not None}")

    # The OAuthToken object carries the raw access token string in `.token`;
    # without a token object the implementation receives None.
    hf_token = oauth_token.token if oauth_token else None

    outcome = blindfold_model_impl(
        model_url,
        hf_token,
        repo_name,
        private_repo,
        # Adapter so the implementation only needs a two-argument callable
        # and stays independent of Gradio's progress API.
        lambda pct, msg: progress(pct, msg),
    )

    elapsed = time.time() - started_at
    logger.info(f"=== Gradio blindfold_model returned after {elapsed:.2f} seconds ===")
    return outcome
def create_interface():
    """Create and return the Gradio interface with OAuth support.

    Layout: an intro banner, then a row with an input column (login button,
    model URL/ID, output repository name, privacy checkbox, submit button)
    and an output column (status text and resulting model URL), followed by
    a "how it works" section.

    Event wiring:
        * Changing the model URL calls ``generate_default_repo_name`` and
          writes its result into the repository-name textbox.
        * Clicking the submit button calls ``blindfold_model`` and writes its
          result into the status and URL outputs.

    The login button is deliberately NOT connected to ``blindfold_model``:
    it only performs the Hugging Face sign-in/sign-out flow. Wiring the job
    to it would start a clone-and-push run on every login or logout click.

    Returns:
        The constructed ``gr.Blocks`` interface (not yet launched).
    """
    with gr.Blocks(title="Blindfold Model - Remove Vision Components") as interface:
        gr.Markdown(
            """
            # 🙈 Blindfold Model Tool
            Remove vision components from Qwen models to create text-only versions.
            This tool clones a model from Hugging Face, removes vision-related tensors,
            and pushes the processed model to a new repository.
            **Login with your Hugging Face account to get started.**
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 📥 Input Settings")

                # LoginButton handles the OAuth redirect on its own; it needs
                # no click handler from us.
                gr.LoginButton(elem_id="hf_login")

                model_url = gr.Textbox(
                    label="Hugging Face Model URL or ID",
                    placeholder="e.g., Qwen/Qwen3.5-35B-A3B",
                    value="",
                )
                repo_name = gr.Textbox(
                    label="Output Repository Name",
                    placeholder="e.g., Qwen3.5-35B-A3B-BLIND",
                    value="",
                )
                private_repo = gr.Checkbox(label="Make repository private", value=False)
                submit_btn = gr.Button(
                    "🚀 Process & Upload", variant="primary", size="lg"
                )

                # Update repo_name when model_url changes
                model_url.change(
                    fn=generate_default_repo_name,
                    inputs=[model_url],
                    outputs=[repo_name],
                )

            with gr.Column(scale=1):
                gr.Markdown("### 📤 Output")
                output_status = gr.Markdown(
                    "Ready to process. Login with your Hugging Face account first.",
                    label="Status",
                )
                output_url = gr.Textbox(
                    label="Output Model URL",
                    interactive=False,
                    placeholder="URL will appear here after completion",
                )

        gr.Markdown(
            """
            ---
            ### 📋 How it works:
            1. Login with your Hugging Face account (OAuth)
            2. Enter the source model URL/ID
            3. Clones model using git
            4. Removes all vision-related tensors (`model.visual.*` and `mtp.*`)
            5. Updates model config to remove vision sections
            6. Creates a new repository on your Hugging Face account
            7. Pushes the processed model using git
            ### ⚠️ Requirements:
            - A Hugging Face account with **write** permissions
            - The source model must use safetensors format
            ### 🔒 Privacy:
            Your Hugging Face token is handled securely through OAuth and is never stored.
            """
        )

        # Only the submit button starts processing. The OAuth token is not
        # listed in `inputs`: Gradio injects it automatically based on the
        # gr.OAuthToken type hint in blindfold_model.
        submit_btn.click(
            fn=blindfold_model,
            inputs=[model_url, repo_name, private_repo],
            outputs=[output_status, output_url],
        )

    return interface