Blindfold / core.py
Flagstone8878's picture
Update core.py
70e0d4c verified
"""
Core business logic for processing Qwen models to remove vision components.
This module contains no Gradio dependencies and can be imported independently.
"""
import json
import logging
import os
import shutil
import subprocess
import tempfile
import time
import traceback
from pathlib import Path

from huggingface_hub import create_repo, login, whoami
from safetensors import safe_open
from safetensors.torch import save_file
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Durations in seconds of the major steps of blindfold_model_impl, keyed by
# step name ("git_clone", "model_processing", "git_push"). Kept for debugging;
# the dict is module-level, so entries persist from one call to the next.
operation_timings = {}
# ----------------------------------------------------------------------
# Configuration: only tensors whose names start with one of these prefixes
# are removed by process_model; every other tensor is kept unchanged.
# "model.visual." selects the vision tensors; "mtp." selects the MTP tensors.
# NOTE(review): MTP presumably stands for Multi-Token Prediction rather than a
# vision component - confirm against the Qwen3.5 model documentation.
VISION_PREFIXES = ("model.visual.", "mtp.")
# ----------------------------------------------------------------------
def generate_default_repo_name(model_url: str) -> str:
    """
    Generate a default repository name from the model URL.

    Appends '-BLIND' to the model name to indicate it's a text-only version.

    Args:
        model_url: Hugging Face model URL or ID (surrounding whitespace and
            leading/trailing slashes are ignored)

    Returns:
        str: Default repository name, e.g. "Qwen3.5-35B-A3B-BLIND"
    """
    model_id = model_url.strip()
    if "huggingface.co" in model_id:
        # Keep only the path that follows the host name.
        model_id = model_id.split("huggingface.co/")[-1]
    # Strip slashes for plain IDs as well as URLs; otherwise "org/model/"
    # would produce an empty model name and the result would be "-BLIND".
    model_id = model_id.strip("/")
    # The model name is the last path segment (the whole ID if there is no "/").
    model_name = model_id.rsplit("/", 1)[-1]
    return f"{model_name}-BLIND"
# ----------------------------------------------------------------------
def filter_tensor_names(keys, prefixes):
    """
    Split tensor names into the ones to keep and the ones to remove.

    A name is removed when it starts with at least one entry of ``prefixes``.
    Both returned lists preserve the order in which names appear in ``keys``.

    Args:
        keys: Iterable of tensor names
        prefixes: Iterable of name prefixes marking tensors for removal

    Returns:
        tuple: (keep, remove) lists of tensor names
    """
    kept, dropped = [], []
    for name in keys:
        is_match = any(name.startswith(prefix) for prefix in prefixes)
        bucket = dropped if is_match else kept
        bucket.append(name)
    return kept, dropped
def process_model(input_dir, output_dir, progress_callback=None):
    """
    Process the model by removing vision components.

    Every shard listed in ``model.safetensors.index.json`` is rewritten
    without the tensors whose names start with one of ``VISION_PREFIXES``,
    a new index is written, and the vision/MTP keys are dropped from
    ``config.json``. ``input_dir`` and ``output_dir`` may be the same
    directory, in which case the model is rewritten in place.

    Args:
        input_dir: Path to the downloaded model directory
        output_dir: Path to save the processed model
        progress_callback: Optional callback ``(fraction, message)`` for
            progress updates

    Returns:
        dict: Summary with keys ``removed_tensors`` (count), ``output_dir``
        (str), ``total_size`` (bytes of the kept tensors) and
        ``original_size`` (bytes of the original shard files)

    Raises:
        FileNotFoundError: If ``model.safetensors.index.json`` or
            ``config.json`` is missing from ``input_dir``.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # In-place processing needs special care below: copying a file onto
    # itself raises shutil.SameFileError, and shards must not be overwritten
    # while they are still being read.
    in_place = input_dir.resolve() == output_dir.resolve()
    index_name = "model.safetensors.index.json"

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    # ------------------------------------------------------------------
    # 0. Load the index once and calculate the original model size.
    report(0.05, "Calculating original model size...")
    index_path = input_dir / index_name
    if not index_path.exists():
        raise FileNotFoundError("model.safetensors.index.json not found.")
    with open(index_path, "r", encoding="utf-8") as f:
        index = json.load(f)
    weight_map = index["weight_map"]
    shard_files = sorted(set(weight_map.values()))
    total_shards = len(shard_files)
    original_size = sum(
        (input_dir / shard).stat().st_size
        for shard in shard_files
        if (input_dir / shard).exists()
    )

    # ------------------------------------------------------------------
    # 1. Copy all non-weight files (everything except .safetensors and .bin).
    #    The index is skipped because it is recreated in step 4.
    report(0.1, "Copying non-weight files...")
    if not in_place:
        for item in input_dir.iterdir():
            if not item.is_file():
                continue
            if item.name.endswith((".safetensors", ".bin")):
                continue
            if item.name == index_name:
                continue
            shutil.copy2(item, output_dir / item.name)

    # ------------------------------------------------------------------
    # 2. Process each shard.
    new_weight_map = {}
    removed_keys = []
    total_size = 0
    for i, shard in enumerate(shard_files):
        shard_path = input_dir / shard
        report(
            0.1 + (0.6 * (i + 1) / total_shards),
            f"Processing {shard} ({i+1}/{total_shards})...",
        )
        tensors = {}
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            # Carry the shard's header metadata over to the rewritten file.
            metadata = f.metadata()
            keep_keys, rem_keys = filter_tensor_names(list(f.keys()), VISION_PREFIXES)
            removed_keys.extend(rem_keys)
            for key in keep_keys:
                tensor = f.get_tensor(key)
                tensors[key] = tensor
                # Actual byte size of the tensor, whatever its dtype.
                total_size += tensor.numel() * tensor.element_size()

        out_shard_path = output_dir / shard
        if not tensors:
            logger.warning("No tensors kept in %s, skipping.", shard)
            # Make sure no stale copy of a fully-removed shard is left in the
            # output (when processing in place this is the original shard).
            out_shard_path.unlink(missing_ok=True)
            continue

        # Write to a sibling file and swap it in afterwards, so a shard is
        # never overwritten while its own tensors may still be backed by the
        # file they were read from (in-place processing).
        tmp_shard_path = out_shard_path.with_name(out_shard_path.name + ".tmp")
        save_file(tensors, tmp_shard_path, metadata=metadata)
        tmp_shard_path.replace(out_shard_path)

        new_weight_map.update({key: shard for key in tensors})

    # ------------------------------------------------------------------
    # 3. Write a new index.
    report(0.75, "Writing new index...")
    new_index = {"metadata": {"total_size": total_size}, "weight_map": new_weight_map}
    with open(output_dir / index_name, "w", encoding="utf-8") as f:
        json.dump(new_index, f, indent=2)

    # ------------------------------------------------------------------
    # 4. Modify config.json to remove vision sections.
    report(0.85, "Updating config.json...")
    with open(input_dir / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    # Remove vision-related (and video token) sections if present.
    for key in (
        "vision_config",
        "image_token_id",
        "vision_start_token_id",
        "vision_end_token_id",
        "video_token_id",
    ):
        config.pop(key, None)
    # Remove the MTP-related keys that go with the removed "mtp." tensors.
    if "text_config" in config:
        config["text_config"].pop("mtp_num_hidden_layers", None)
        config["text_config"].pop("mtp_use_dedicated_embeddings", None)
    # Note: We keep the original architecture name since the underlying model structure
    # (MoE with linear attention) is the same, just without vision components.
    # Qwen3_5MoeForConditionalGeneration is supported in llama.cpp.
    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # ------------------------------------------------------------------
    report(1.0, "Processing complete!")
    return {
        "removed_tensors": len(removed_keys),
        "output_dir": str(output_dir),
        "total_size": total_size,
        "original_size": original_size,
    }
def _redact_secret(text, secret):
    """Return ``text`` with every occurrence of ``secret`` masked."""
    if not secret:
        return text
    return text.replace(str(secret), "***")


def _run_git(args, hf_token, cwd=None):
    """
    Run ``git <args>`` non-interactively and return the completed process.

    Raises:
        RuntimeError: If git exits with a non-zero status. The message holds
            the command and git's stderr with the token masked, because the
            clone URLs embed the token.
    """
    # Extend the parent environment instead of replacing it: git still needs
    # variables such as PATH and HOME (global config, helper programs).
    env = {**os.environ, "GIT_TERMINAL_PROMPT": "0"}
    command = ["git", *args]
    try:
        return subprocess.run(
            command,
            cwd=cwd,
            check=True,
            capture_output=True,
            env=env,
        )
    except subprocess.CalledProcessError as exc:
        stderr = (exc.stderr or b"").decode("utf-8", errors="replace").strip()
        safe_command = _redact_secret(" ".join(command), hf_token)
        details = _redact_secret(stderr, hf_token) or "no error output"
        # "from None" keeps the original exception, whose text contains the
        # raw command line, out of any logged traceback.
        raise RuntimeError(
            f"`{safe_command}` failed with exit code {exc.returncode}: {details}"
        ) from None


def blindfold_model_impl(
    model_url: str,
    hf_token: str,
    repo_name: str,
    private_repo: bool,
    progress_callback=None,
):
    """
    Internal implementation function that processes the model with a given token.
    This function has no Gradio dependencies.

    Steps: authenticate, git-clone the source model into a temporary
    directory, strip the vision components in place, create (or reuse) the
    output repository, copy the processed files into a clone of it, then
    commit and push. Step durations are recorded in ``operation_timings``.
    Failures are logged and reported through the returned message rather
    than raised.

    Args:
        model_url: Hugging Face model URL (e.g., "Qwen/Qwen3.5-35B-A3B")
        hf_token: Hugging Face access token
        repo_name: Name for the output repository
        private_repo: Whether to make the repo private
        progress_callback: Optional callback ``(fraction, message)`` for
            progress updates

    Returns:
        tuple: (status_message, output_url); ``output_url`` is "" on failure
    """

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    start_time = time.time()
    logger.info(
        f"=== blindfold_model_impl STARTED at {time.strftime('%Y-%m-%d %H:%M:%S')} ==="
    )
    logger.info(f"Model URL: {model_url}, Repo name: {repo_name}")
    try:
        # Clean up model URL
        model_id = model_url.strip()
        if "huggingface.co" in model_id:
            model_id = model_id.split("huggingface.co/")[-1].strip("/")

        # Validate token
        if not hf_token:
            logger.error("No token provided")
            return "❌ Error: Please login with your Hugging Face account first.", ""
        logger.info(f"Token type: {type(hf_token)}, Token length: {len(hf_token)}")

        # Verify token and get username
        # Note: OAuth tokens don't need login() - they can be used directly
        report(0.05, "Authenticating with Hugging Face...")
        try:
            user_info = whoami(token=hf_token)
        except Exception as e:
            auth_error = _redact_secret(str(e), hf_token)
            logger.error(f"Authentication failed: {auth_error}")
            logger.error(f"Exception type: {type(e)}")
            logger.error(_redact_secret(traceback.format_exc(), hf_token))
            return f"❌ Error: Authentication failed - {auth_error}", ""
        username = user_info["name"]

        # Everything happens inside a temporary directory that is removed
        # afterwards (including the clones, whose git config holds the token).
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            clone_dir = temp_path / "source"
            clone_dir.mkdir(parents=True)

            # Clone source model repo using git
            report(0.1, f"Cloning model {model_id}... (this may take several minutes)")
            clone_url = f"https://oauth2:{hf_token}@huggingface.co/{model_id}"
            logger.info(
                f"Starting git clone of {model_id} at {time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            clone_start = time.time()
            _run_git(["clone", "--verbose", clone_url, str(clone_dir)], hf_token)
            clone_duration = time.time() - clone_start
            logger.info(f"Git clone completed in {clone_duration:.2f} seconds")
            operation_timings["git_clone"] = clone_duration
            report(
                0.25,
                f"Clone completed in {clone_duration:.0f}s. Processing model...",
            )

            # Process model in-place; map its 0..1 progress onto 0.3..0.8.
            report(0.3, "Processing model (removing vision components)...")

            def inner_progress_callback(pct, msg):
                report(0.3 + 0.5 * pct, msg)

            logger.info(
                f"Starting model processing at {time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            process_start = time.time()
            result = process_model(
                input_dir=str(clone_dir),
                output_dir=str(clone_dir),  # Process in-place
                progress_callback=inner_progress_callback,
            )
            process_duration = time.time() - process_start
            logger.info(f"Model processing completed in {process_duration:.2f} seconds")
            operation_timings["model_processing"] = process_duration

            # Create output repo
            report(0.85, f"Creating repository {username}/{repo_name}...")
            create_repo(
                repo_id=repo_name,
                token=hf_token,
                private=private_repo,
                exist_ok=True,
                repo_type="model",
            )

            # Clone output repo
            report(0.9, "Setting up output repository...")
            output_clone_dir = temp_path / "output"
            output_clone_dir.mkdir(parents=True)
            clone_url = (
                f"https://oauth2:{hf_token}@huggingface.co/{username}/{repo_name}"
            )
            _run_git(["clone", clone_url, str(output_clone_dir)], hf_token)

            # Copy processed files to output repo
            report(0.92, "Copying processed files...")
            for item in clone_dir.iterdir():
                # Never copy the source clone's git metadata: it would
                # overwrite the output clone's .git (history and remote URL).
                if item.name == ".git":
                    continue
                dest = output_clone_dir / item.name
                if item.is_file():
                    shutil.copy2(item, dest)
                elif item.is_dir():
                    shutil.copytree(item, dest, dirs_exist_ok=True)

            # Commit and push
            report(0.95, "Committing changes...")
            repo_cwd = str(output_clone_dir)
            _run_git(["add", "."], hf_token, cwd=repo_cwd)
            # "git commit" fails when nothing is staged, e.g. when re-running
            # against a repository that already holds identical content.
            status = _run_git(["status", "--porcelain"], hf_token, cwd=repo_cwd)
            if status.stdout.strip():
                _run_git(
                    ["commit", "-m", "Remove vision components from model"],
                    hf_token,
                    cwd=repo_cwd,
                )
                report(0.97, "Pushing to repository... (this may take several minutes)")
                logger.info(f"Starting git push at {time.strftime('%Y-%m-%d %H:%M:%S')}")
                push_start = time.time()
                _run_git(["push", "--verbose"], hf_token, cwd=repo_cwd)
                push_duration = time.time() - push_start
                logger.info(f"Git push completed in {push_duration:.2f} seconds")
                operation_timings["git_push"] = push_duration
            else:
                logger.info("Output repository is already up to date; nothing to push")

            output_url = f"https://huggingface.co/{username}/{repo_name}"

            # Calculate size reduction (both sizes are byte counts).
            original_size = result.get("original_size", 0)
            new_size = result.get("total_size", 0)
            removed_size = original_size - new_size
            reduction_pct = (
                (removed_size / original_size * 100) if original_size > 0 else 0
            )
            total_duration = time.time() - start_time
            logger.info(
                f"=== blindfold_model_impl COMPLETED in {total_duration:.2f} seconds ==="
            )
            logger.info(f"Operation timings: {operation_timings}")
            gib = 1024**3
            success_msg = "\n".join(
                [
                    "✅ **Success!**",
                    "**Model processed and pushed successfully!**",
                    "📊 **Summary:**",
                    f"- Source model: `{model_id}`",
                    f"- Output repository: `{username}/{repo_name}`",
                    f"- Removed {result['removed_tensors']} vision-related tensors",
                    f"- Original size: {original_size:,} bytes ({original_size / gib:.2f} GB)",
                    f"- New size: {new_size:,} bytes ({new_size / gib:.2f} GB)",
                    f"- **Removed: {removed_size:,} bytes ({removed_size / gib:.2f} GB)"
                    f" - {reduction_pct:.1f}% reduction**",
                    f"🔗 **Access your model here:** {output_url}",
                ]
            )
            return success_msg, output_url
    except Exception as e:
        total_duration = time.time() - start_time
        # The token must never reach the logs or the user-facing message.
        safe_error = _redact_secret(str(e), hf_token)
        logger.error(
            f"=== blindfold_model_impl FAILED after {total_duration:.2f} seconds ==="
        )
        logger.error(f"Operation timings before failure: {operation_timings}")
        logger.error(f"Error: {safe_error}")
        logger.error(f"Exception type: {type(e).__name__}")
        logger.error(
            f"Traceback:\n{_redact_secret(traceback.format_exc(), hf_token)}"
        )
        error_msg = f"❌ **Error:** {safe_error}"
        return error_msg, ""