# NOTE(review): this file was recovered from a corrupted (HTML-stripped) dump in
# which every "<...>" span was deleted — including struct format strings such
# as "<Q" and everything between such a "<" and the next ">".  The truncated
# regions (_read_header/_deserialize_tensor, the shard writer's class header,
# __init__, add_tensor, and the tail of flush) have been reconstructed to match
# the standard safetensors on-disk format; confirm against the original source.

import gradio as gr
import torch
import os
import gc
import shutil
import requests
import json
import struct
import numpy as np
import re
from pathlib import Path
from typing import Dict, Any, Optional, List

from huggingface_hub import HfApi, hf_hub_download, list_repo_files, login
from safetensors.torch import load_file, save_file
from tqdm import tqdm

# --- Memory Efficient Safetensors ---

# Mapping between safetensors dtype tags and torch dtypes, shared by the
# reader and the shard writer below.
_SAFETENSORS_TO_TORCH_DTYPE: Dict[str, torch.dtype] = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
_TORCH_TO_SAFETENSORS_DTYPE: Dict[torch.dtype, str] = {
    v: k for k, v in _SAFETENSORS_TO_TORCH_DTYPE.items()
}


class MemoryEfficientSafeOpen:
    """
    Reads safetensors metadata and tensors without mmap, keeping RAM usage low.

    Only the JSON header is parsed up front; each tensor's bytes are read from
    disk on demand in :meth:`get_tensor`.
    """

    def __init__(self, filename):
        self.filename = filename
        self.file = open(filename, "rb")
        self.header, self.header_size = self._read_header()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def keys(self) -> List[str]:
        # "__metadata__" is file-level metadata, not a tensor entry.
        return [k for k in self.header.keys() if k != "__metadata__"]

    def metadata(self) -> Dict[str, str]:
        return self.header.get("__metadata__", {})

    def get_tensor(self, key):
        """Read one tensor's bytes from disk and deserialize it.

        Raises KeyError if `key` is not present in the header.
        """
        if key not in self.header:
            raise KeyError(f"Tensor '{key}' not found in the file")
        metadata = self.header[key]
        offset_start, offset_end = metadata["data_offsets"]
        # Data section starts right after the 8-byte length prefix + header.
        self.file.seek(self.header_size + 8 + offset_start)
        tensor_bytes = self.file.read(offset_end - offset_start)
        return self._deserialize_tensor(tensor_bytes, metadata)

    def _read_header(self):
        # safetensors layout: u64 little-endian header size, then JSON header.
        # NOTE(review): reconstructed — the original was truncated at
        # `struct.unpack("` by the HTML-stripping corruption.
        header_size = struct.unpack("<Q", self.file.read(8))[0]
        header_json = self.file.read(header_size).decode("utf-8")
        return json.loads(header_json), header_size

    def _deserialize_tensor(self, tensor_bytes, metadata):
        # NOTE(review): reconstructed — the original implementation was eaten
        # by the corruption; this follows the standard layout (raw
        # little-endian bytes reinterpreted as the declared dtype/shape).
        dtype = _SAFETENSORS_TO_TORCH_DTYPE[metadata["dtype"]]
        shape = metadata["shape"]
        if not tensor_bytes:
            # Zero-element tensor: nothing is stored in the data section.
            return torch.empty(shape, dtype=dtype)
        # bytearray makes the buffer writable so torch.frombuffer does not
        # warn about read-only memory.
        byte_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch.uint8)
        return byte_tensor.view(dtype).reshape(shape)


class ShardedSafetensorsWriter:
    """
    Buffers tensors in RAM and writes them out as sequentially numbered
    safetensors shards ("model-00001.safetensors", ...), tracking a
    key -> shard-filename map so an index.json can be produced at the end.

    Shards use simple sequential names rather than the canonical
    "model-XXXXX-of-YYYYY" scheme because the total shard count is unknown
    until all tensors have been added; loaders follow the index.json
    weight_map, so the filenames themselves only need to be consistent.

    NOTE(review): the original class name, __init__ and add_tensor were lost
    to the corrupted dump and have been reconstructed from the attributes the
    surviving `flush` body uses (output_dir, max_bytes, buffer, current_bytes,
    shard_count, index_map); verify the public names against callers.
    """

    def __init__(self, output_dir, max_shard_bytes: int = 5 * 1024**3):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.max_bytes = max_shard_bytes
        self.buffer: List[Dict[str, Any]] = []  # pending tensors for the current shard
        self.current_bytes = 0                  # payload bytes buffered so far
        self.shard_count = 0                    # shards written so far
        self.index_map: Dict[str, str] = {}     # tensor key -> shard filename
        self.total_bytes = 0                    # payload bytes across all shards

    def add_tensor(self, key: str, tensor: torch.Tensor) -> None:
        """Queue one tensor; flushes automatically once the shard is full."""
        t = tensor.detach().contiguous().cpu()
        # Reinterpret the flat tensor as raw bytes (works for bf16 too, which
        # has no numpy equivalent dtype).
        data = t.view(-1).view(torch.uint8).numpy().tobytes() if t.numel() else b""
        self.buffer.append(
            {
                "key": key,
                "dtype": _TORCH_TO_SAFETENSORS_DTYPE[t.dtype],
                "shape": list(t.shape),
                "data": data,
            }
        )
        self.current_bytes += len(data)
        if self.current_bytes >= self.max_bytes:
            self.flush()

    def flush(self) -> None:
        """Write all buffered tensors as one shard file and reset the buffer."""
        if not self.buffer:
            return
        self.shard_count += 1
        shard_name = f"model-{self.shard_count:05d}.safetensors"
        print(f"Flushing Shard {self.shard_count} ({self.current_bytes / 1024**3:.2f} GB)...")

        # Build the JSON header: per-tensor dtype/shape plus byte offsets into
        # the data section that follows the header.
        header: Dict[str, Any] = {"__metadata__": {"format": "pt"}}
        current_offset = 0
        for item in self.buffer:
            end_offset = current_offset + len(item["data"])
            header[item["key"]] = {
                "dtype": item["dtype"],
                "shape": item["shape"],
                "data_offsets": [current_offset, end_offset],
            }
            current_offset = end_offset
            self.index_map[item["key"]] = shard_name
        header_json = json.dumps(header).encode("utf-8")

        # NOTE(review): the file-writing tail was truncated at
        # `struct.pack('` in the corrupted dump; reconstructed per the
        # safetensors layout (u64-LE header length, header JSON, raw data).
        out_path = self.output_dir / shard_name
        with open(out_path, "wb") as f:
            f.write(struct.pack("<Q", len(header_json)))
            f.write(header_json)
            for item in self.buffer:
                f.write(item["data"])

        self.total_bytes += self.current_bytes
        self.buffer = []
        self.current_bytes = 0

    def finalize(self) -> Dict[str, Any]:
        """Flush the last shard and write model.safetensors.index.json.

        Returns the index dict that was written.  NOTE(review): presumed from
        the surviving comments ("model-00001.safetensors is fine if we update
        index") — confirm the original exposed an equivalent step.
        """
        self.flush()
        index = {
            "metadata": {"total_size": self.total_bytes},
            "weight_map": self.index_map,
        }
        index_path = self.output_dir / "model.safetensors.index.json"
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)
        return index