Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
| import streamlit as st | |
| from huggingface_hub import hf_hub_download | |
class BitNetManager:
    """Clone, patch, build and run Microsoft's BitNet engine from a Streamlit app.

    All paths are derived from the directory containing this module, and all
    progress/error reporting is written to the Streamlit page via ``st``.
    """

    # Declaration in src/ggml-bitnet-mad.cpp that fails to compile with
    # "cannot initialize a variable of type 'int8_t *' with an rvalue of type
    # 'const int8_t *'", and its const-correct replacement.
    _CONST_PATCH_OLD = "int8_t * y_col = y + col * by;"
    _CONST_PATCH_NEW = "const int8_t * y_col = y + col * by;"

    # The huggingface-cli download call patch_source looks for in setup_env.py,
    # and the in-process huggingface_hub call that replaces it.
    _CLI_DOWNLOAD_LINE = 'run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model")'
    _API_DOWNLOAD_LINE = 'from huggingface_hub import snapshot_download; snapshot_download(repo_id=hf_url, local_dir=model_dir)'
    # Fallback matcher for the same run_command([... "huggingface-cli", "download", ...])
    # call written with different quoting or spacing.
    _CLI_DOWNLOAD_PATTERN = r'run_command\(\s*\[\s*["\']huggingface-cli["\'],\s*["\']download["\'],[^\]]+\][^)]*\)'

    def __init__(self, repo_url="https://github.com/microsoft/BitNet.git"):
        """Record the repository URL and derive the checkout/build paths.

        Args:
            repo_url: Git URL of the BitNet repository to clone.
        """
        self.repo_url = repo_url
        # Anchor everything next to this file, independent of the process cwd.
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.bitnet_dir = os.path.join(self.base_dir, "BitNet")
        self.build_dir = os.path.join(self.bitnet_dir, "build")

    @classmethod
    def _patch_const_qualifier(cls, content):
        """Return *content* with the non-const ``y_col`` declaration made const.

        Idempotent: a declaration already preceded by ``const `` is left alone
        (so re-running never yields ``const const int8_t``), and the word
        boundary keeps e.g. ``uint8_t * y_col ...`` from being touched.
        """
        import re

        pattern = r"(?<!const )\b" + re.escape(cls._CONST_PATCH_OLD)
        return re.sub(pattern, lambda _match: cls._CONST_PATCH_NEW, content)

    @classmethod
    def _patch_download_command(cls, content):
        """Swap the huggingface-cli download call in *content* for the Python API.

        Returns:
            ``(patched_content, how)`` where ``how`` is ``"exact"`` if the known
            line was replaced, ``"regex"`` if the fallback pattern matched, or
            ``None`` if nothing was changed (``patched_content is content``).
        """
        import re

        if cls._CLI_DOWNLOAD_LINE in content:
            return content.replace(cls._CLI_DOWNLOAD_LINE, cls._API_DOWNLOAD_LINE), "exact"
        if "huggingface-cli" in content:
            patched, count = re.subn(
                cls._CLI_DOWNLOAD_PATTERN,
                lambda _match: cls._API_DOWNLOAD_LINE,
                content,
            )
            if count:
                return patched, "regex"
        return content, None

    def patch_source(self):
        """Patch the BitNet checkout to fix known build and download problems.

        Two best-effort patches are applied to files that exist:

        1. ``src/ggml-bitnet-mad.cpp``: make the ``y_col`` pointer ``const``
           to fix a const-qualifier compilation error. Safe to re-run.
        2. ``setup_env.py``: replace the ``huggingface-cli download`` subprocess
           call with an in-process ``huggingface_hub.snapshot_download`` call.

        Outcomes are reported through Streamlit; nothing is returned.
        """
        target_file = os.path.join(self.bitnet_dir, "src", "ggml-bitnet-mad.cpp")
        if os.path.exists(target_file):
            st.info("Patching source code to fix const-qualifier error...")
            with open(target_file, "r", encoding="utf-8") as f:
                content = f.read()
            patched_content = self._patch_const_qualifier(content)
            if patched_content != content:
                with open(target_file, "w", encoding="utf-8") as f:
                    f.write(patched_content)
                st.success("Patch applied to ggml-bitnet-mad.cpp!")
            elif self._CONST_PATCH_NEW in content:
                # A previous run already applied the fix; nothing to do.
                st.info("ggml-bitnet-mad.cpp is already patched.")
            else:
                st.warning("Patch target line not found in ggml-bitnet-mad.cpp.")

        setup_script = os.path.join(self.bitnet_dir, "setup_env.py")
        if os.path.exists(setup_script):
            st.info("Patching setup_env.py to use Python API instead of CLI...")
            with open(setup_script, "r", encoding="utf-8") as f:
                setup_content = f.read()
            patched_setup, how = self._patch_download_command(setup_content)
            if how is not None:
                with open(setup_script, "w", encoding="utf-8") as f:
                    f.write(patched_setup)
            if how == "exact":
                st.success("Successfully patched setup_env.py with Python API!")
            elif how == "regex":
                st.success("Successfully patched setup_env.py (via regex)!")
            elif "huggingface-cli" in setup_content:
                st.warning("Could not find the exact download command in setup_env.py to patch.")

    def setup_engine(self, model_id="1bitLLM/bitnet_b1_58-3B"):
        """Clone, patch and build BitNet via its official ``setup_env.py``.

        The output of ``setup_env.py`` is streamed into the Streamlit page.

        Args:
            model_id: Hugging Face repo id passed to ``setup_env.py --hf-repo``;
                its last path component names the ``models/`` subdirectory
                checked for ``ggml-model-i2_s.gguf``.

        Returns:
            True when the engine and model are ready (already present or set
            up successfully); False when cloning, patching or setup failed.
        """
        import sys

        model_name = model_id.split("/")[-1]
        model_path = os.path.join(self.bitnet_dir, "models", model_name, "ggml-model-i2_s.gguf")
        binary = self.get_binary_path()
        # get_binary_path() may fall back to run_inference.py, which exists as
        # soon as the repo is cloned; only a compiled binary counts as built.
        engine_built = binary is not None and not binary.endswith(".py")

        if engine_built and os.path.exists(model_path):
            st.success(f"BitNet engine and model ({model_name}) are ready!")
            return True
        if engine_built:
            st.info(f"Engine binary found, but model weights for {model_name} are missing. Starting setup...")

        try:
            if not os.path.exists(self.bitnet_dir):
                st.info("Cloning BitNet repository...")
                # Clone straight into self.bitnet_dir: a bare `git clone URL`
                # lands in the process cwd, which need not be self.base_dir.
                subprocess.run(
                    ["git", "clone", "--recursive", self.repo_url, self.bitnet_dir],
                    check=True,
                )

            self.patch_source()

            st.info("Running official BitNet setup (setup_env.py)...")
            # sys.executable keeps the child on this app's interpreter and
            # environment; -u makes its output unbuffered so logs stream live.
            cmd = [sys.executable, "-u", "setup_env.py", "--hf-repo", model_id, "--use-pretuned"]
            logs = []
            log_container = st.empty()
            with subprocess.Popen(
                cmd,
                cwd=self.bitnet_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            ) as process:
                for line in process.stdout:
                    logs.append(line)
                    # Show only the tail to keep the UI readable.
                    log_container.code("".join(logs[-15:]))
            # Leaving the with-block waits for the child, so returncode is set.
            if process.returncode != 0:
                st.error(f"Setup failed (Exit {process.returncode})")
                self._show_setup_failure_log(logs)
                return False
        except Exception as e:
            # UI boundary: report any clone/patch/setup failure instead of
            # crashing the Streamlit script.
            st.error(f"Execution error durante setup: {e}")
            return False

        st.success("Official Setup Completed Successfully!")
        return True

    def _show_setup_failure_log(self, logs):
        """Display the most specific log available after ``setup_env.py`` fails.

        Prefers ``logs/download_model.log``, then ``logs/compile.log``, and
        otherwise shows the tail of the streamed output lines in *logs*.
        """
        logs_dir = os.path.join(self.bitnet_dir, "logs")
        down_log = os.path.join(logs_dir, "download_model.log")
        comp_log = os.path.join(logs_dir, "compile.log")
        if os.path.exists(down_log):
            st.info("Download Log (logs/download_model.log):")
            st.code(self._read_tail(down_log, 3000))
        elif os.path.exists(comp_log):
            st.info("Compilation Log (logs/compile.log):")
            st.code(self._read_tail(comp_log, 3000))
        else:
            st.info("Detailed Output:")
            st.code("".join(logs)[-2000:])

    @staticmethod
    def _read_tail(path, limit):
        """Return the last *limit* characters of the text file at *path*."""
        # errors="replace" so stray non-UTF-8 bytes in tool output can't abort
        # the error report itself.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()[-limit:]

    def get_binary_path(self):
        """Return the first existing inference entry point, or None.

        Candidates are checked in order: compiled ``llama-cli``/``bitnet``
        binaries in the usual build locations (including Windows Release
        folders), then the ``run_inference.py`` wrapper script as a fallback.
        """
        possible_paths = [
            os.path.join(self.bitnet_dir, "build", "bin", "llama-cli"),  # Standard location
            os.path.join(self.bitnet_dir, "build", "llama-cli"),  # Alternate location
            os.path.join(self.bitnet_dir, "build", "bitnet"),  # Legacy/Custom
            os.path.join(self.bitnet_dir, "build", "bin", "bitnet"),
            os.path.join(self.bitnet_dir, "build", "Release", "bitnet.exe"),  # Windows
            os.path.join(self.bitnet_dir, "build", "bin", "Release", "llama-cli.exe"),
            os.path.join(self.bitnet_dir, "run_inference.py"),  # Script fallback
        ]
        return next((p for p in possible_paths if os.path.exists(p)), None)

    def download_model(self, model_id="1bitLLM/bitnet_b1_58-3B", filename="ggml-model-i2_s.gguf"):
        """Locate model weights previously generated by ``setup_env.py``.

        Nothing is downloaded here; the file is looked up at
        ``<bitnet_dir>/models/<model_name>/<filename>``, where ``model_name``
        is the last path component of *model_id*.

        Returns:
            The model file path if it exists, otherwise None (after reporting
            the problem in the Streamlit UI).
        """
        model_name = model_id.split("/")[-1]
        local_model_path = os.path.join(self.bitnet_dir, "models", model_name, filename)
        if os.path.exists(local_model_path):
            st.success(f"Found local model: {model_name}")
            return local_model_path
        st.error(f"Model file not found at {local_model_path}")
        st.info("The GGUF model must be generated by the 'Initialize Engine' process. Please run it again to download and convert the weights.")
        return None

    def run_inference(self, prompt, model_path, n_predict=128):
        """Start the BitNet engine on *prompt* and return the running process.

        Args:
            prompt: Text passed to the engine via ``-p``.
            model_path: Model file passed via ``-m``.
            n_predict: Value passed via ``-n`` (defaults to 128, the value
                previously hard-coded).

        Returns:
            A ``subprocess.Popen`` with text-mode ``stdout`` and ``stderr``
            pipes so the caller can stream the response, or None if no entry
            point was found or the launch failed.
        """
        import sys

        binary = self.get_binary_path()
        if not binary:
            st.error("Inference binary not found. Please re-run Initialization.")
            return None

        args = ["-m", model_path, "-p", prompt, "-n", str(n_predict)]
        if binary.endswith(".py"):
            # Run the wrapper script with the interpreter hosting this app.
            cmd = [sys.executable, binary, *args]
        else:
            cmd = [binary, *args]

        try:
            # cwd must be the checkout so run_inference.py can find build/.
            # NOTE(review): stdout and stderr are separate pipes; a caller that
            # drains only one of them can stall the child once the other pipe's
            # OS buffer fills — confirm the caller reads both.
            return subprocess.Popen(
                cmd,
                cwd=self.bitnet_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )
        except Exception as e:
            st.error(f"Inference execution failed: {e}")
            return None