File size: 11,023 Bytes
093459e
 
002a426
093459e
c88e367
b393dd6
 
a34e50c
fd54d78
22e6bfc
 
093459e
a34e50c
 
 
 
 
7c1c345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a34e50c
7c1c345
a34e50c
ee7dcb0
b393dd6
 
 
 
d304a44
7b75877
e3370cf
87a6815
706397b
ef4154d
565775d
466da76
7582e13
38200dc
5755197
5513c2f
c09542b
ee7dcb0
093459e
 
 
fd54d78
 
 
 
093459e
fd54d78
ee7dcb0
093459e
fd54d78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
093459e
fd54d78
 
56bfe89
fd54d78
56bfe89
 
fd54d78
9f3c282
 
 
 
a909c71
9f3c282
 
 
 
 
093459e
 
9f3c282
 
fd54d78
093459e
a909c71
fd54d78
093459e
 
 
 
b7dc420
ee7dcb0
 
c88e367
093459e
c88e367
 
 
093459e
 
002a426
ee7dcb0
 
 
 
f8491dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee7dcb0
 
 
093459e
ee7dcb0
56bfe89
c88e367
093459e
c88e367
 
 
 
 
 
 
093459e
 
 
 
 
 
c88e367
093459e
 
c88e367
 
 
093459e
c88e367
 
 
 
002a426
 
093459e
c88e367
 
 
002a426
 
 
 
 
 
 
 
 
 
 
ee7dcb0
002a426
ee7dcb0
 
 
 
 
c88e367
093459e
c88e367
093459e
93e5181
88bf6aa
 
 
aee6552
 
88bf6aa
 
aee6552
88bf6aa
 
 
 
93e5181
 
 
 
 
 
88bf6aa
 
93e5181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
093459e
2ac72fa
093459e
 
632bd5b
093459e
 
 
88bf6aa
36fe5ed
 
 
 
 
 
 
 
 
 
 
 
eab4117
36fe5ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88bf6aa
36fe5ed
 
 
 
 
 
 
 
 
 
 
cab2b06
093459e
0bd8d77
0c489f4
 
0bd8d77
002a426
 
 
2ac72fa
 
002a426
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# Core runtime dependencies: Gradio UI, PyTorch, and HF Transformers
# (tokenizer + causal LM + streaming generation support).
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import time
from threading import Thread
import sys
import os
# NOTE(review): these env vars are set AFTER `import torch`; OMP_NUM_THREADS
# and MKL settings generally must be set before torch/MKL initialize to take
# effect — verify they are honored in this ordering.
# os.environ["BNB_CUDA_VERSION"] = "0" # Forces bitsandbytes to recognize no GPU
os.environ["OMP_NUM_THREADS"] = "1" # Prevents race conditions in custom CPU kernels
os.environ["VECLIB_MAXIMUM_ISA"] = "AVX2"  # Cap Apple Accelerate/vecLib at AVX2
os.environ["MKL_DEBUG_CPU_TYPE"] = "5" # Forces MKL to use AVX2

try:
    import spaces
except ImportError:
    spaces = None

if spaces is None or not torch.cuda.is_available():
    print("Using CPU-only mode (spaces.GPU disabled)")

    class SpacesShim:
        """Drop-in stand-in for the HF `spaces` module on CPU-only hosts."""

        def GPU(self, *args, **kwargs):
            # Bare usage: @spaces.GPU — the decorated function arrives as
            # the single positional argument; hand it straight back.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            # Parameterized usage: @spaces.GPU(duration=...) — return a
            # no-op decorator that leaves the function unchanged.
            return lambda func: func

    spaces = SpacesShim()

def gpu_decorator(func):
    """Wrap *func* with spaces.GPU (the real one or the shim)."""
    return spaces.GPU()(func)

# Model configuration: an existing local path passed as argv[1] overrides
# the default Hugging Face Hub repository.
# Checkpoints previously used during experiments (kept for reference):
#   TobDeBer/SmolLM3-3B-hirma-{b80s,b60s,b60}-0.5, -b60-bnb4, -q20, -q20-bnb8,
#   -q80-bnb4; TobDeBer/SmolLM2-135M-Instruct-{hirma-b60s-0.5,b100,q99-bnb4};
#   HuggingFaceTB/SmolLM2-135M-Instruct
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
    MODEL_NAME = sys.argv[1]
    print(f"Using local model from: {MODEL_NAME}")
else:
    MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b100-0.5"

# Globals populated by load_model() at startup.
tokenizer = None
model = None

import platform
import subprocess

# py-cpuinfo is explicitly optional ('pip install py-cpuinfo'); importing it
# unconditionally crashed startup when it was missing — guard the import.
try:
    import cpuinfo  # noqa: F401
except ImportError:
    cpuinfo = None

def load_model():
    """Load the Smol LLM model and tokenizer with hardware detection.

    Populates the module-level ``tokenizer`` and ``model`` globals.
    Returns a human-readable status string and never raises.
    """
    global tokenizer, model
    try:
        # --- Hardware audit: log CPU identity and SIMD capability flags ---
        print("--- Hardware Audit ---")
        print(f"Processor: {platform.processor()}")
        print(f"Machine: {platform.machine()}")

        try:
            # lscpu is Linux-only; this is best-effort on cloud containers.
            cpu_flags = subprocess.check_output("lscpu", shell=True).decode()
            print("Instruction sets found:")
            for flag in ["avx512", "avx2", "avx", "fma", "amx"]:
                if flag in cpu_flags.lower():
                    print(f"  ✅ {flag.upper()} supported")
                else:
                    print(f"  ❌ {flag.upper()} NOT found")
        except Exception as e:
            print(f"Could not check CPU flags: {e}")

        print(f"PyTorch version: {torch.__version__}")
        print(f"Loading model: {MODEL_NAME}")
        print("----------------------")

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Pick device and dtype together: bf16 on GPU, fp32 on CPU
        # (bf16 kernels are slow or unsupported on many CPUs).
        if torch.cuda.is_available():
            print("  ✅ CUDA detected. Loading model on GPU.")
            device_map = "auto"
            dtype = torch.bfloat16
        else:
            print("  ⚠️ No CUDA detected. Loading model on CPU.")
            device_map = {"": "cpu"}
            dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            dtype=dtype,
            device_map=device_map,
            low_cpu_mem_usage=True
        )
        # BUGFIX: an unconditional model.to(torch.bfloat16) here used to
        # clobber the fp32 choice made for the CPU path above; the dtype
        # selected at load time is now final.

        return "✅ Model loaded successfully!"
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"

@spaces.GPU(duration=30)
def chat_predict(message, history, max_length, temperature, top_p, repetition_penalty, system_prompt):
    """Stream a chat completion for *message* given the running *history*.

    Yields progressively longer strings (Gradio streaming protocol); the
    final yield carries a tokens/sec summary footer.
    """
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "⚠️ Please wait for the model to finish loading..."
        return

    try:
        # Build the chat-template message list: system prompt first,
        # then prior turns, then the new user message.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            # Gradio 6 may hand back multimodal content as a list of
            # parts; keep only the text pieces.
            if isinstance(content, list):
                text_content = ""
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        text_content += part.get("text", "")
                content = text_content

            # Ensure content is a string before templating.
            if not isinstance(content, str):
                content = str(content)

            # Strip the stats footer appended to earlier assistant turns
            # so it is not fed back into the model.
            if role == "assistant" and "\n\n---\n*Generated" in content:
                content = content.split("\n\n---\n*Generated")[0]

            messages.append({"role": role, "content": content})

        messages.append({"role": "user", "content": message})

        # Format the prompt via the model's chat template.
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print("formatted_prompt: ", formatted_prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

        # Stream decoded text from generate() running on a worker thread.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        generated_text = ""
        start_time = time.time()
        token_count = 0  # counts streamer chunks — an approximation of tokens
        last_update_time = start_time
        current_stats = ""

        for new_text in streamer:
            generated_text += new_text
            token_count += 1

            # Refresh the throughput footer at most every 0.2 s.
            current_time = time.time()
            if current_time - last_update_time > 0.2:
                elapsed = current_time - start_time
                if elapsed > 0:
                    tps = token_count / elapsed
                    current_stats = f"\n\n---\n*Generating... ({tps:.1f} t/s)*"
                last_update_time = current_time

            yield generated_text + current_stats

        # Generation is finished once the streamer is exhausted; reap the
        # worker thread.
        thread.join()

        # BUGFIX: always emit a final yield so the transient "Generating..."
        # footer never lingers in the last message (previously skipped
        # entirely when elapsed_time was 0).
        elapsed_time = time.time() - start_time
        if elapsed_time > 0:
            tps = token_count / elapsed_time
            stats = f"\n\n---\n*Generated {token_count} tokens in {elapsed_time:.2f}s ({tps:.2f} t/s)*"
        else:
            stats = ""
        yield generated_text + stats

    except Exception as e:
        yield f"❌ Error during generation: {str(e)}"

# Custom CSS to force full height and style chat.
# NOTE(review): Gradio applies custom CSS via the gr.Blocks(css=...)
# constructor argument — confirm this string is actually wired there.
css = """
.gradio-container {
    height: 100vh !important;
    max-height: 100vh !important;
    overflow: hidden !important;
}
#main-row {
    height: calc(100vh - 150px) !important;
}
#chat-col {
    height: 100% !important;
}
/* Thin box around prompt field - targeting specifically within chat column */
#chat-col textarea {
    border: 1px solid #64748b !important;
    border-radius: 8px !important;
    padding: 8px !important;
}
"""

# Create custom theme with smaller base font.
# NOTE(review): Gradio themes take effect via gr.Blocks(theme=...) —
# confirm this object is actually passed to the Blocks constructor.
custom_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="md",
    spacing_size="sm",
    radius_size="md"
).set(
    # Fine-tune individual component styles on top of the Soft base.
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    block_title_text_weight="600",
)

# Build the Gradio interface.
# theme and css must be supplied to the gr.Blocks constructor — they were
# previously defined above but never passed in, so they had no effect.
with gr.Blocks(theme=custom_theme, css=css, fill_height=True) as demo:
    gr.Markdown(
        """
        # 🤖 Smol LLM Chat - Multi-turn chat with SmolLM3-3B.
        """
    )

    with gr.Row(elem_id="main-row"):
        with gr.Column(scale=1, min_width=200):
            with gr.Accordion("⚙️ Parameters", open=False):
                # Generation controls; forwarded to chat_predict via
                # ChatInterface.additional_inputs below, in the same order
                # as the chat_predict signature after (message, history).
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=1024,
                    value=200,
                    step=50,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.1,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty"
                )
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value="You are a helpful AI assistant. Provide clear and concise answers.",
                    lines=2
                )

        with gr.Column(scale=4, elem_id="chat-col"):
            # Streaming chat UI backed by chat_predict.
            chat_interface = gr.ChatInterface(
                fn=chat_predict,
                fill_height=True,
                additional_inputs=[
                    max_tokens,
                    temperature,
                    top_p,
                    repetition_penalty,
                    system_prompt
                ],
            )

# Auto-load the model at import/startup so the first request does not pay
# the model-loading cost.
load_status = load_model()
print(f"Startup load status: {load_status}")

if __name__ == "__main__":
    # BUGFIX: theme/css are gr.Blocks() constructor arguments in Gradio,
    # not Blocks.launch() arguments — passing them here raises TypeError.
    demo.launch(
        share=False,
        show_error=True
    )