File size: 9,446 Bytes
2544845
 
 
 
 
 
 
 
 
 
 
 
7d4c953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f478088
7d4c953
f478088
7d4c953
39f5a1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a60b206
2363817
b4c8d46
804d64f
b4c8d46
 
f478088
b4c8d46
 
 
 
f478088
b4c8d46
 
 
f478088
2544845
 
 
 
7d4c953
 
804d64f
9136547
f478088
 
9136547
804d64f
 
 
 
 
 
 
 
 
 
 
 
2544845
9136547
804d64f
d288235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
053b326
2544845
9136547
 
 
 
804d64f
9136547
2544845
804d64f
 
 
ca64889
 
 
 
804d64f
ca64889
804d64f
 
 
 
 
 
 
 
b4c8d46
 
 
 
 
 
804d64f
b4c8d46
804d64f
 
b4c8d46
 
 
2544845
 
804d64f
 
 
 
 
2544845
804d64f
 
 
 
 
 
2544845
804d64f
ca64889
2544845
 
ca64889
2544845
 
804d64f
 
2544845
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import os
import subprocess
import streamlit as st
from huggingface_hub import hf_hub_download

class BitNetManager:
    """Clone, patch, build and run the BitNet inference engine for a Streamlit app.

    All paths are derived from the directory containing this source file: the
    repository checkout is expected at ``<base_dir>/BitNet`` and model weights
    at ``<base_dir>/BitNet/models/<model_name>/``. Progress and errors are
    reported through Streamlit (``st.info`` / ``st.error`` / ...).
    """

    # Name of the converted weights file this class looks for.
    _GGUF_FILENAME = "ggml-model-i2_s.gguf"

    # Declaration in ggml-bitnet-mad.cpp that fails to compile, and its fix.
    # NOTE: the buggy text is a substring of the fixed text, so a naive
    # str.replace() would keep prepending 'const' on every run.
    _CONST_BUG = "int8_t * y_col = y + col * by;"
    _CONST_FIX = "const int8_t * y_col = y + col * by;"

    # Download call in setup_env.py that shells out to huggingface-cli, and the
    # Python-API replacement (kept on one line so indentation is preserved).
    _CLI_DOWNLOAD_CALL = 'run_command(["huggingface-cli", "download", hf_url, "--local-dir", model_dir], log_step="download_model")'
    _API_DOWNLOAD_CALL = 'from huggingface_hub import snapshot_download; snapshot_download(repo_id=hf_url, local_dir=model_dir)'
    # Fallback matcher for the same call written with other quotes/spacing.
    _CLI_DOWNLOAD_PATTERN = r'run_command\(\s*\[\s*["\']huggingface-cli["\'],\s*["\']download["\'],[^\]]+\][^)]*\)'

    def __init__(self, repo_url="https://github.com/microsoft/BitNet.git"):
        """Record the repository URL and derive the checkout/build directories.

        Args:
            repo_url: Git URL cloned by ``setup_engine`` when no checkout exists.
        """
        self.repo_url = repo_url
        # Anchor everything to this file's directory, not the process cwd.
        self.base_dir = os.path.dirname(os.path.abspath(__file__))
        self.bitnet_dir = os.path.join(self.base_dir, "BitNet")
        self.build_dir = os.path.join(self.bitnet_dir, "build")

    # ------------------------------------------------------------------
    # Pure helpers (no Streamlit, no filesystem)
    # ------------------------------------------------------------------

    @staticmethod
    def _python_executable():
        """Return the interpreter running this app, falling back to ``python``."""
        import sys

        # Using the current interpreter keeps child scripts in the same
        # environment (the patched setup_env.py imports huggingface_hub).
        return sys.executable or "python"

    @classmethod
    def _fix_const_qualifier(cls, content):
        """Return ``content`` with the const-qualifier fix applied exactly once.

        Occurrences already preceded by ``const `` are left alone, so calling
        this on already-patched text returns it unchanged.
        """
        import re

        pattern = r"(?<!const )\b" + re.escape(cls._CONST_BUG)
        return re.sub(pattern, lambda _match: cls._CONST_FIX, content)

    @classmethod
    def _patch_download_command(cls, setup_content):
        """Swap the huggingface-cli download call for the Python API call.

        Returns:
            A ``(content, outcome)`` tuple where ``outcome`` is one of
            ``"exact"`` (literal line replaced), ``"regex"`` (variant replaced
            via the fallback pattern), ``"unmatched"`` (the CLI is mentioned
            but no download call could be located) or ``"absent"`` (the CLI
            is not mentioned at all). ``content`` is unchanged for the last two.
        """
        if cls._CLI_DOWNLOAD_CALL in setup_content:
            patched = setup_content.replace(cls._CLI_DOWNLOAD_CALL, cls._API_DOWNLOAD_CALL)
            return patched, "exact"
        if "huggingface-cli" not in setup_content:
            return setup_content, "absent"

        import re

        patched, replaced = re.subn(
            cls._CLI_DOWNLOAD_PATTERN,
            lambda _match: cls._API_DOWNLOAD_CALL,
            setup_content,
        )
        return patched, ("regex" if replaced else "unmatched")

    @classmethod
    def _build_setup_command(cls, model_id):
        """Return the argv that runs setup_env.py for ``model_id``."""
        # -u: unbuffered child output so logs can be streamed line by line.
        return [
            cls._python_executable(),
            "-u",
            "setup_env.py",
            "--hf-repo",
            model_id,
            "--use-pretuned",
        ]

    @classmethod
    def _build_inference_command(cls, binary, model_path, prompt, n_predict):
        """Return the argv that runs ``binary`` on ``prompt``."""
        args = ["-m", model_path, "-p", prompt, "-n", str(n_predict)]
        if binary.endswith(".py"):
            # Script fallback: must be launched through an interpreter.
            return [cls._python_executable(), binary] + args
        return [binary] + args

    # ------------------------------------------------------------------
    # Source patching
    # ------------------------------------------------------------------

    def patch_source(self):
        """Patch the BitNet checkout to fix known build and download problems.

        Edits, in place and only when the file exists:

        * ``src/ggml-bitnet-mad.cpp`` - adds the missing ``const`` that causes
          "cannot initialize a variable of type 'int8_t *' with an rvalue of
          type 'const int8_t *'".
        * ``setup_env.py`` - replaces the ``huggingface-cli download`` call
          with ``huggingface_hub.snapshot_download``.

        Both patches are safe to apply repeatedly.
        """
        self._patch_mad_kernel()
        self._patch_setup_script()

    def _patch_mad_kernel(self):
        """Apply the const-qualifier fix to ggml-bitnet-mad.cpp, if present."""
        target_file = os.path.join(self.bitnet_dir, "src", "ggml-bitnet-mad.cpp")
        if not os.path.exists(target_file):
            return

        st.info("Patching source code to fix const-qualifier error...")
        with open(target_file, "r", encoding="utf-8") as f:
            content = f.read()

        patched_content = self._fix_const_qualifier(content)
        if patched_content != content:
            with open(target_file, "w", encoding="utf-8") as f:
                f.write(patched_content)
            st.success("Patch applied to ggml-bitnet-mad.cpp!")
        elif self._CONST_FIX in content:
            st.info("ggml-bitnet-mad.cpp is already patched.")
        else:
            st.warning("Patch target line not found in ggml-bitnet-mad.cpp.")

    def _patch_setup_script(self):
        """Make setup_env.py download via the Python API, if the file is present."""
        setup_script = os.path.join(self.bitnet_dir, "setup_env.py")
        if not os.path.exists(setup_script):
            return

        st.info("Patching setup_env.py to use Python API instead of CLI...")
        with open(setup_script, "r", encoding="utf-8") as f:
            setup_content = f.read()

        patched_setup, outcome = self._patch_download_command(setup_content)
        if outcome in ("exact", "regex"):
            with open(setup_script, "w", encoding="utf-8") as f:
                f.write(patched_setup)
            if outcome == "exact":
                st.success("Successfully patched setup_env.py with Python API!")
            else:
                st.success("Successfully patched setup_env.py (via regex)!")
        elif outcome == "unmatched":
            st.warning("Could not find the exact download command in setup_env.py to patch.")
        # outcome == "absent": nothing refers to the CLI, so nothing to do.

    # ------------------------------------------------------------------
    # Engine setup
    # ------------------------------------------------------------------

    def setup_engine(self, model_id="1bitLLM/bitnet_b1_58-3B"):
        """Clone, patch and build BitNet via its setup_env.py, streaming logs.

        Args:
            model_id: Hugging Face repo id passed to ``setup_env.py --hf-repo``.
                Its last path component names the local model directory.

        Returns:
            True if the engine and model are already present or setup finished
            with exit code 0; False if cloning, patching or setup failed (the
            failure is reported through Streamlit).
        """
        model_name = model_id.split("/")[-1]
        model_path = os.path.join(self.bitnet_dir, "models", model_name, self._GGUF_FILENAME)

        # get_binary_path() only ever returns a path that exists.
        if self.get_binary_path():
            if os.path.exists(model_path):
                st.success(f"BitNet engine and model ({model_name}) are ready!")
                return True
            st.info(f"Engine binary found, but model weights for {model_name} are missing. Starting setup...")

        try:
            if not os.path.exists(self.bitnet_dir):
                st.info("Cloning BitNet repository...")
                # Explicit destination: without it git clones into the process
                # cwd, which need not be base_dir.
                subprocess.run(
                    ["git", "clone", "--recursive", self.repo_url, self.bitnet_dir],
                    check=True,
                )

            self.patch_source()

            st.info("Running official BitNet setup (setup_env.py)...")
            log_container = st.empty()
            logs = []
            with subprocess.Popen(
                self._build_setup_command(model_id),
                cwd=self.bitnet_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # single merged stream to display
                text=True,
                errors="replace",  # never abort streaming on odd bytes
                bufsize=1,
            ) as process:
                for line in process.stdout:
                    logs.append(line)
                    # Show only the tail so the UI stays readable.
                    log_container.code("".join(logs[-15:]))
            # Leaving the 'with' block has waited for the child to exit.

            if process.returncode != 0:
                st.error(f"Setup failed (Exit {process.returncode})")
                self._show_failure_logs(logs)
                return False

            st.success("Official Setup Completed Successfully!")
            return True
        except Exception as e:
            # UI boundary: report any failure instead of crashing the app.
            st.error(f"Execution error durante setup: {e}")
            return False

    def _show_failure_logs(self, captured_lines):
        """Display the tail of every setup log found, else the captured output."""
        logs_dir = os.path.join(self.bitnet_dir, "logs")
        candidates = (
            ("Download Log (logs/download_model.log):", "download_model.log"),
            ("Compilation Log (logs/compile.log):", "compile.log"),
        )

        shown_any = False
        for title, log_name in candidates:
            log_path = os.path.join(logs_dir, log_name)
            if not os.path.exists(log_path):
                continue
            st.info(title)
            with open(log_path, "r", encoding="utf-8", errors="replace") as f:
                st.code(f.read()[-3000:])
            shown_any = True

        if not shown_any:
            st.info("Detailed Output:")
            st.code("".join(captured_lines)[-2000:])

    # ------------------------------------------------------------------
    # Lookup helpers
    # ------------------------------------------------------------------

    def get_binary_path(self):
        """Locate the bitnet binary based on platform/build structure.

        Returns:
            The first existing candidate path, in priority order, or None when
            none exists. The last candidate is the ``run_inference.py`` script.
        """
        candidates = (
            ("build", "bin", "llama-cli"),  # Standard location
            ("build", "llama-cli"),  # Alternate location
            ("build", "bitnet"),  # Legacy/Custom
            ("build", "bin", "bitnet"),
            ("build", "Release", "bitnet.exe"),  # Windows
            ("build", "bin", "Release", "llama-cli.exe"),
            ("run_inference.py",),  # Script fallback
        )
        for parts in candidates:
            path = os.path.join(self.bitnet_dir, *parts)
            if os.path.exists(path):
                return path
        return None

    def download_model(self, model_id="1bitLLM/bitnet_b1_58-3B", filename="ggml-model-i2_s.gguf"):
        """Locate the model weights. These are generated locally by setup_env.py.

        Despite the name, nothing is downloaded here; the method only checks
        ``<bitnet_dir>/models/<model_name>/<filename>``.

        Args:
            model_id: Hugging Face repo id; only its last path component is used.
            filename: Weights file name inside the model directory.

        Returns:
            The path to the weights file if it exists, otherwise None (after
            reporting the problem through Streamlit).
        """
        model_name = model_id.split("/")[-1]
        local_model_path = os.path.join(self.bitnet_dir, "models", model_name, filename)

        if os.path.exists(local_model_path):
            st.success(f"Found local model: {model_name}")
            return local_model_path

        st.error(f"Model file not found at {local_model_path}")
        st.info("The GGUF model must be generated by the 'Initialize Engine' process. Please run it again to download and convert the weights.")
        return None

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def run_inference(self, prompt, model_path, n_predict=128):
        """Start the bitnet binary on ``prompt`` and return the running process.

        Args:
            prompt: Text passed to the binary via ``-p``.
            model_path: Path to the weights file, passed via ``-m``.
            n_predict: Value passed via ``-n``. Defaults to 128.

        Returns:
            A ``subprocess.Popen`` whose ``stdout`` and ``stderr`` are separate
            text pipes, so the caller can stream the response; or None if the
            binary or model path is missing or the process could not be started.
            Callers should drain ``stderr`` as well as ``stdout``: an unread
            pipe that fills up blocks the child.
        """
        binary = self.get_binary_path()
        if not binary:
            st.error("Inference binary not found. Please re-run Initialization.")
            return None
        if not model_path:
            st.error("Model path is missing. Please re-run Initialization.")
            return None

        cmd = self._build_inference_command(binary, model_path, prompt, n_predict)

        try:
            # cwd must be the checkout so run_inference.py can find build/.
            return subprocess.Popen(
                cmd,
                cwd=self.bitnet_dir,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )
        except Exception as e:
            st.error(f"Inference execution failed: {e}")
            return None