"""
Core business logic for processing Qwen models to remove vision components.

This module contains no Gradio dependencies and can be imported independently.
"""

import json
import logging
import os
import shutil
import subprocess
import tempfile
import time
import traceback
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file
from huggingface_hub import create_repo, login, whoami

# Set up logging
logger = logging.getLogger(__name__)

# Durations (in seconds) of the long-running steps of the most recent
# blindfold_model_impl() call, keyed by step name. Kept for debugging and
# cleared at the start of every call so stale values never leak into logs.
operation_timings = {}

# ----------------------------------------------------------------------
# Configuration: only tensors starting with these prefixes are removed.
# For Qwen3.5 models we remove both the vision tower ("model.visual.") and the
# MTP (multi-token prediction) weights ("mtp.").
VISION_PREFIXES = ("model.visual.", "mtp.")

# Top-level config.json keys describing the vision tower and vision tokens.
_VISION_CONFIG_KEYS = (
    "vision_config",
    "image_token_id",
    "vision_start_token_id",
    "vision_end_token_id",
    "video_token_id",
)
# text_config keys describing the MTP layers whose "mtp." tensors are removed.
_MTP_TEXT_CONFIG_KEYS = ("mtp_num_hidden_layers", "mtp_use_dedicated_embeddings")

_INDEX_FILENAME = "model.safetensors.index.json"
_HF_HOST = "huggingface.co"
_GIB = 1024**3
# Cap on how much git stderr is echoed back in an error message.
_MAX_GIT_STDERR_CHARS = 2000


# ----------------------------------------------------------------------
def _normalize_model_id(model_url: str) -> str:
    """Reduce a Hugging Face model URL (or bare model ID) to the bare repo ID.

    >>> _normalize_model_id(" https://huggingface.co/Qwen/Qwen3.5-35B-A3B/ ")
    'Qwen/Qwen3.5-35B-A3B'
    """
    model_id = model_url.strip()
    if _HF_HOST in model_id:
        model_id = model_id.split(f"{_HF_HOST}/")[-1].strip("/")
    return model_id


def generate_default_repo_name(model_url: str) -> str:
    """
    Generate a default repository name from the model URL.

    Appends '-BLIND' to the model name to indicate it's a text-only version.

    Args:
        model_url: Hugging Face model URL or ID

    Returns:
        str: Default repository name

    >>> generate_default_repo_name("https://huggingface.co/Qwen/Qwen3.5-35B-A3B")
    'Qwen3.5-35B-A3B-BLIND'
    """
    model_id = _normalize_model_id(model_url)
    # The model name is whatever follows the last "/" (the whole ID if none).
    model_name = model_id.rsplit("/", 1)[-1]
    return f"{model_name}-BLIND"


# ----------------------------------------------------------------------
def filter_tensor_names(keys, prefixes):
    """Split tensor names into ``(keep, remove)`` lists.

    A name goes to ``remove`` when it starts with any of *prefixes*;
    otherwise it goes to ``keep``. Input order is preserved in both lists.
    """
    # str.startswith accepts a tuple, testing every prefix in one call.
    prefix_tuple = tuple(prefixes)
    keep, remove = [], []
    for key in keys:
        (remove if key.startswith(prefix_tuple) else keep).append(key)
    return keep, remove


def _load_weight_map(input_dir: Path) -> dict:
    """Return the ``weight_map`` (tensor name -> shard file) from the index.

    Raises:
        FileNotFoundError: if the safetensors index file is missing.
    """
    index_path = input_dir / _INDEX_FILENAME
    if not index_path.exists():
        raise FileNotFoundError("model.safetensors.index.json not found.")
    with open(index_path, "r", encoding="utf-8") as f:
        return json.load(f)["weight_map"]


def process_model(input_dir, output_dir, progress_callback=None):
    """
    Process the model by removing vision components.

    Tensors whose names start with any of ``VISION_PREFIXES`` are dropped
    from every shard, a new safetensors index is written, and the vision /
    MTP keys are removed from ``config.json``. ``input_dir`` and
    ``output_dir`` may be the same directory (in-place processing).

    Args:
        input_dir: Path to the downloaded model directory
        output_dir: Path to save the processed model
        progress_callback: Optional ``callback(fraction, message)`` for
            progress updates

    Returns:
        dict: Summary of the processing results, with keys
        ``removed_tensors`` (count), ``output_dir``, ``total_size`` (bytes of
        tensor data kept, as written to the index metadata), ``original_size``
        (bytes on disk of the original shards) and ``output_size`` (bytes on
        disk of the written shards).

    Raises:
        FileNotFoundError: if the safetensors index or config.json is missing.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # blindfold_model_impl() processes in place; copying a file onto itself
    # raises shutil.SameFileError, so several steps special-case this.
    in_place = input_dir.resolve() == output_dir.resolve()

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    # ------------------------------------------------------------------
    # 0. Load the index and measure the original shards on disk.
    report(0.05, "Calculating original model size...")
    weight_map = _load_weight_map(input_dir)
    shard_files = sorted(set(weight_map.values()))
    total_shards = len(shard_files)

    original_size = 0
    for shard in shard_files:
        shard_path = input_dir / shard
        if shard_path.exists():
            original_size += shard_path.stat().st_size

    # ------------------------------------------------------------------
    # 1. Copy all non-weight files (everything except .safetensors and .bin).
    #    The index is skipped because it is regenerated in step 3.
    report(0.1, "Copying non-weight files...")
    if not in_place:
        for item in input_dir.iterdir():
            if (
                item.is_file()
                and not item.name.endswith((".safetensors", ".bin"))
                and item.name != _INDEX_FILENAME
            ):
                shutil.copy2(item, output_dir / item.name)

    # ------------------------------------------------------------------
    # 2. Filter each shard.
    new_weight_map = {}
    removed_keys = []
    total_size = 0  # bytes of tensor data kept
    output_size = 0  # bytes on disk of the shards written
    for i, shard in enumerate(shard_files, start=1):
        shard_path = input_dir / shard
        report(
            0.1 + 0.6 * i / total_shards,
            f"Processing {shard} ({i}/{total_shards})...",
        )

        tensors = {}
        with safe_open(shard_path, framework="pt", device="cpu") as f:
            # Keep the shard's header metadata (e.g. {"format": "pt"}).
            shard_metadata = f.metadata()
            keep_keys, rem_keys = filter_tensor_names(f.keys(), VISION_PREFIXES)
            removed_keys.extend(rem_keys)
            for key in keep_keys:
                tensor = f.get_tensor(key)
                tensors[key] = tensor
                # numel() is a method; element_size() gives the real dtype width.
                total_size += tensor.numel() * tensor.element_size()

        if not tensors:
            logger.warning("No tensors kept in %s, skipping.", shard)
            if in_place:
                # The shard is no longer referenced by the new index; remove
                # it so it is not published alongside the processed model.
                shard_path.unlink()
            continue

        out_shard_path = output_dir / shard
        out_shard_path.parent.mkdir(parents=True, exist_ok=True)
        save_file(tensors, out_shard_path, metadata=shard_metadata)
        output_size += out_shard_path.stat().st_size
        new_weight_map.update(dict.fromkeys(tensors, shard))

    # ------------------------------------------------------------------
    # 3. Write a new index.
    report(0.75, "Writing new index...")
    new_index = {"metadata": {"total_size": total_size}, "weight_map": new_weight_map}
    with open(output_dir / _INDEX_FILENAME, "w", encoding="utf-8") as f:
        json.dump(new_index, f, indent=2)

    # ------------------------------------------------------------------
    # 4. Remove the vision and MTP sections from config.json.
    report(0.85, "Updating config.json...")
    with open(input_dir / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    for key in _VISION_CONFIG_KEYS:
        config.pop(key, None)
    text_config = config.get("text_config")
    if isinstance(text_config, dict):
        for key in _MTP_TEXT_CONFIG_KEYS:
            text_config.pop(key, None)
    # Note: the original architecture name is kept on purpose; the remaining
    # text model structure is unchanged, only the vision/MTP parts are gone.

    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    # ------------------------------------------------------------------
    report(1.0, "Processing complete!")

    return {
        "removed_tensors": len(removed_keys),
        "output_dir": str(output_dir),
        "total_size": total_size,
        "original_size": original_size,
        "output_size": output_size,
    }


# ----------------------------------------------------------------------
def _authenticated_url(token: str, repo_id: str) -> str:
    """Return an HTTPS git URL for *repo_id* with *token* embedded as credentials."""
    return f"https://oauth2:{token}@{_HF_HOST}/{repo_id}"


def _redact(text: str, secret) -> str:
    """Replace every occurrence of *secret* in *text* with ``***``."""
    if isinstance(secret, str) and secret:
        return text.replace(secret, "***")
    return text


def _run_git(args, secret, cwd=None):
    """Run ``git <args>`` non-interactively and return the CompletedProcess.

    The current environment is inherited (PATH, HOME and git/LFS config are
    needed) with terminal prompts disabled.

    Raises:
        RuntimeError: when git exits non-zero. The message includes git's
            stderr with *secret* redacted. The original CalledProcessError is
            deliberately not chained because its text contains the
            token-bearing command line.
    """
    try:
        return subprocess.run(
            ["git", *args],
            cwd=None if cwd is None else str(cwd),
            check=True,
            capture_output=True,
            env={**os.environ, "GIT_TERMINAL_PROMPT": "0"},
        )
    except subprocess.CalledProcessError as err:
        stderr = (err.stderr or b"").decode("utf-8", errors="replace").strip()
        message = f"git {args[0]} failed with exit code {err.returncode}"
        if stderr:
            message += f": {stderr[-_MAX_GIT_STDERR_CHARS:]}"
        raise RuntimeError(_redact(message, secret)) from None


def _copy_repo_contents(src_dir: Path, dst_dir: Path) -> None:
    """Copy the working-tree contents of *src_dir* into *dst_dir*.

    The ``.git`` directory is skipped. Copying it would overwrite the
    destination clone's git metadata, including its ``origin`` remote.
    """
    for item in src_dir.iterdir():
        if item.name == ".git":
            continue
        dest = dst_dir / item.name
        if item.is_file():
            shutil.copy2(item, dest)
        elif item.is_dir():
            shutil.copytree(item, dest, dirs_exist_ok=True)


def blindfold_model_impl(
    model_url: str,
    hf_token: str,
    repo_name: str,
    private_repo: bool,
    progress_callback=None,
):
    """
    Internal implementation function that processes the model with a given token.

    Clones the source model, strips its vision/MTP components in place,
    and pushes the result to an output repository. This function has no
    Gradio dependencies. Failures are reported through the returned status
    message rather than raised.

    Args:
        model_url: Hugging Face model URL (e.g., "Qwen/Qwen3.5-35B-A3B")
        hf_token: Hugging Face access token
        repo_name: Name for the output repository. A bare name is created
            under the token owner's namespace; "owner/name" is used as-is.
        private_repo: Whether to make the repo private
        progress_callback: Optional ``callback(fraction, message)`` for
            progress updates

    Returns:
        tuple: (status_message, output_url); output_url is "" on failure.
    """
    start_time = time.time()
    operation_timings.clear()
    logger.info(
        f"=== blindfold_model_impl STARTED at {time.strftime('%Y-%m-%d %H:%M:%S')} ==="
    )
    logger.info(f"Model URL: {model_url}, Repo name: {repo_name}")

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    try:
        model_id = _normalize_model_id(model_url)

        if not hf_token:
            logger.error("No token provided")
            return "❌ Error: Please login with your Hugging Face account first.", ""

        logger.info(f"Token type: {type(hf_token)}, Token length: {len(hf_token)}")

        # Verify the token and get the username. OAuth tokens don't need
        # login(); they are passed to each call directly.
        report(0.05, "Authenticating with Hugging Face...")
        try:
            user_info = whoami(token=hf_token)
        except Exception as e:
            logger.error(f"Authentication failed: {e}")
            logger.error(f"Exception type: {type(e)}")
            logger.error(traceback.format_exc())
            return f"❌ Error: Authentication failed - {str(e)}", ""

        username = user_info["name"]
        output_repo_id = repo_name if "/" in repo_name else f"{username}/{repo_name}"

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            clone_dir = temp_path / "source"
            clone_dir.mkdir(parents=True)

            # 1. Clone the source model repo.
            report(0.1, f"Cloning model {model_id}... (this may take several minutes)")
            logger.info(
                f"Starting git clone of {model_id} at {time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            clone_start = time.time()
            _run_git(
                ["clone", "--verbose", _authenticated_url(hf_token, model_id), str(clone_dir)],
                secret=hf_token,
            )
            clone_duration = time.time() - clone_start
            logger.info(f"Git clone completed in {clone_duration:.2f} seconds")
            operation_timings["git_clone"] = clone_duration
            report(0.25, f"Clone completed in {clone_duration:.0f}s. Processing model...")

            # 2. Strip vision components in place.
            report(0.3, "Processing model (removing vision components)...")

            def inner_progress_callback(pct, msg):
                # Map process_model's 0..1 progress onto the 0.3..0.8 range.
                report(0.3 + 0.5 * pct, msg)

            logger.info(
                f"Starting model processing at {time.strftime('%Y-%m-%d %H:%M:%S')}"
            )
            process_start = time.time()
            result = process_model(
                input_dir=str(clone_dir),
                output_dir=str(clone_dir),  # Process in-place
                progress_callback=inner_progress_callback,
            )
            process_duration = time.time() - process_start
            logger.info(f"Model processing completed in {process_duration:.2f} seconds")
            operation_timings["model_processing"] = process_duration

            # 3. Create and clone the output repo.
            report(0.85, f"Creating repository {output_repo_id}...")
            create_repo(
                repo_id=output_repo_id,
                token=hf_token,
                private=private_repo,
                exist_ok=True,
                repo_type="model",
            )

            report(0.9, "Setting up output repository...")
            output_clone_dir = temp_path / "output"
            output_clone_dir.mkdir(parents=True)
            _run_git(
                ["clone", _authenticated_url(hf_token, output_repo_id), str(output_clone_dir)],
                secret=hf_token,
            )

            # 4. Copy the processed files, then commit and push.
            report(0.92, "Copying processed files...")
            _copy_repo_contents(clone_dir, output_clone_dir)

            report(0.95, "Committing changes...")
            _run_git(["add", "."], secret=hf_token, cwd=output_clone_dir)
            status = _run_git(["status", "--porcelain"], secret=hf_token, cwd=output_clone_dir)
            if status.stdout.strip():
                _run_git(
                    ["commit", "-m", "Remove vision components from model"],
                    secret=hf_token,
                    cwd=output_clone_dir,
                )

                report(0.97, "Pushing to repository... (this may take several minutes)")
                logger.info(f"Starting git push at {time.strftime('%Y-%m-%d %H:%M:%S')}")
                push_start = time.time()
                _run_git(["push", "--verbose"], secret=hf_token, cwd=output_clone_dir)
                push_duration = time.time() - push_start
                logger.info(f"Git push completed in {push_duration:.2f} seconds")
                operation_timings["git_push"] = push_duration
            else:
                # A plain `git commit` would fail with "nothing to commit" here.
                logger.info("Output repository already up to date; nothing to push.")

            output_url = f"https://{_HF_HOST}/{output_repo_id}"

            # Size reduction, comparing on-disk shard sizes before and after.
            original_size = result.get("original_size", 0)
            new_size = result.get("output_size", 0)
            removed_size = original_size - new_size
            reduction_pct = removed_size / original_size * 100 if original_size > 0 else 0

            total_duration = time.time() - start_time
            logger.info(
                f"=== blindfold_model_impl COMPLETED in {total_duration:.2f} seconds ==="
            )
            logger.info(f"Operation timings: {operation_timings}")

            success_msg = "\n".join(
                [
                    "✅ **Success!**",
                    "",
                    "**Model processed and pushed successfully!**",
                    "",
                    "📊 **Summary:**",
                    f"- Source model: `{model_id}`",
                    f"- Output repository: `{output_repo_id}`",
                    f"- Removed {result['removed_tensors']} vision-related tensors",
                    f"- Original size: {original_size:,} bytes "
                    f"({original_size / _GIB:.2f} GB)",
                    f"- New size: {new_size:,} bytes ({new_size / _GIB:.2f} GB)",
                    f"- **Removed: {removed_size:,} bytes ({removed_size / _GIB:.2f} GB)"
                    f" - {reduction_pct:.1f}% reduction**",
                    "",
                    "🔗 **Access your model here:**",
                    output_url,
                ]
            )
            return success_msg, output_url

    except Exception as e:
        total_duration = time.time() - start_time
        # Never echo the access token back to the UI or into the logs.
        error_text = _redact(str(e), hf_token)
        logger.error(
            f"=== blindfold_model_impl FAILED after {total_duration:.2f} seconds ==="
        )
        logger.error(f"Operation timings before failure: {operation_timings}")
        logger.error(f"Error: {error_text}")
        logger.error(f"Exception type: {type(e).__name__}")
        logger.error(f"Traceback:\n{_redact(traceback.format_exc(), hf_token)}")
        return f"❌ **Error:** {error_text}", ""