# Hugging Face Spaces app (the "Spaces: Sleeping" status header from the
# Space's web page is scrape residue, not part of the program).
import os

# FIX: these environment variables are read by OpenMP / MKL / vecLib when the
# native libraries initialize, which happens during `import torch`. They must
# therefore be set BEFORE the heavy imports below, or they have no effect.
# os.environ["BNB_CUDA_VERSION"] = "0"  # Forces bitsandbytes to recognize no GPU
os.environ["OMP_NUM_THREADS"] = "1"  # Prevents race conditions in custom CPU kernels
os.environ["VECLIB_MAXIMUM_ISA"] = "AVX2"
os.environ["MKL_DEBUG_CPU_TYPE"] = "5"  # Forces MKL to use AVX2

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import time
from threading import Thread
import sys
try:
    import spaces
except ImportError:
    spaces = None

if spaces is None or not torch.cuda.is_available():
    print("Using CPU-only mode (spaces.GPU disabled)")

    class SpacesShim:
        """No-op stand-in for the `spaces` module when ZeroGPU is unavailable."""

        def GPU(self, *args, **kwargs):
            # Supports both decorator spellings:
            #   @spaces.GPU               -> the function itself arrives in args
            #   @spaces.GPU(duration=30)  -> we must return a decorator
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def passthrough(func):
                return func

            return passthrough

    spaces = SpacesShim()


def gpu_decorator(func):
    """Wrap `func` with spaces.GPU (a no-op when running CPU-only)."""
    return spaces.GPU()(func)
# Model configuration: an existing local path may be passed as argv[1];
# otherwise fall back to the hard-coded Hub repository selected below.
# The commented-out lines are alternative checkpoints kept for quick switching.
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
    MODEL_NAME = sys.argv[1]
    print(f"Using local model from: {MODEL_NAME}")
else:
    #MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b80s-0.5"
    #MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b60s-0.5"
    MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b100-0.5"
    #MODEL_NAME = "TobDeBer/SmolLM2-135M-Instruct-hirma-b60s-0.5"
    #MODEL_NAME = "TobDeBer/SmolLM2-135M-Instruct-b100"
    ##MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b60-bnb4"
    #MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b60-0.5"
    ##MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-q20-bnb8"
    ##MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-q20"
    # MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-q80-bnb4"
    #MODEL_NAME = "TobDeBer/SmolLM2-135M-Instruct-q99-bnb4"
    #MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct"
# Global model/tokenizer handles, populated by load_model() at startup.
tokenizer = None
model = None

import platform
import subprocess

# FIX: py-cpuinfo is optional and never used below; an unguarded
# `import cpuinfo` crashed startup on hosts where it is not installed.
try:
    import cpuinfo  # noqa: F401  (optional: 'pip install py-cpuinfo')
except ImportError:
    cpuinfo = None
def load_model():
    """Load the model and tokenizer, picking device/dtype from available hardware.

    Prints a short hardware audit (CPU flags via `lscpu`, torch version),
    then populates the module-level `tokenizer` and `model` globals.

    Returns:
        A human-readable status string: "✅ ..." on success, "❌ ..." with the
        exception text on failure (this function never raises).
    """
    global tokenizer, model
    try:
        print("--- Hardware Audit ---")
        print(f"Processor: {platform.processor()}")
        print(f"Machine: {platform.machine()}")
        # Check for CPU flags (instruction sets) on Linux-based cloud hosts.
        try:
            # FIX: invoke lscpu without shell=True (list form, no shell parsing).
            cpu_flags = subprocess.check_output(["lscpu"]).decode()
            print("Instruction sets found:")
            for flag in ["avx512", "avx2", "avx", "fma", "amx"]:
                if flag in cpu_flags.lower():
                    print(f" ✅ {flag.upper()} supported")
                else:
                    print(f" ❌ {flag.upper()} NOT found")
        except Exception as e:
            print(f"Could not check CPU flags: {e}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"Loading model: {MODEL_NAME}")
        print("----------------------")

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Determine device and dtype: bfloat16 on GPU, float32 on CPU
        # (CPU bfloat16 kernels are slow or unsupported on many hosts).
        if torch.cuda.is_available():
            print(" ✅ CUDA detected. Loading model on GPU.")
            device_map = "auto"
            dtype = torch.bfloat16
        else:
            print(" ⚠️ No CUDA detected. Loading model on CPU.")
            device_map = {"": "cpu"}
            dtype = torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            dtype=dtype,
            device_map=device_map,
            low_cpu_mem_usage=True,
        )
        # FIX: removed the unconditional model.to(torch.bfloat16) that
        # clobbered the float32 dtype deliberately chosen for the CPU path.
        return "✅ Model loaded successfully!"
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
def chat_predict(message, history, max_length, temperature, top_p, repetition_penalty, system_prompt):
    """Stream a chat completion for `message` given prior `history`.

    Yields progressively longer strings: the partial completion plus a live
    tokens-per-second footer while generating, then the full text with a
    final stats footer.

    Args:
        message: Latest user message (plain text).
        history: Prior turns as dicts with "role"/"content"; "content" may be
            a list of typed parts (Gradio 6 multimodal format).
        max_length: Passed to generate() as max_new_tokens.
        temperature, top_p, repetition_penalty: Sampling parameters.
        system_prompt: Optional system message prepended to the conversation.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        yield "⚠️ Please wait for the model to finish loading..."
        return
    try:
        # Prepare messages for chat template
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Handle history which can be list of dicts with multimodal content
        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            # Extract text if content is a list (multimodal format in Gradio 6)
            if isinstance(content, list):
                text_content = ""
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        text_content += part.get("text", "")
                content = text_content
            # Ensure content is string
            if not isinstance(content, str):
                content = str(content)
            # Strip the stats footer this function appends to assistant replies
            # so it is not fed back to the model as conversation context.
            if role == "assistant" and "\n\n---\n*Generated" in content:
                content = content.split("\n\n---\n*Generated")[0]
            messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})
        # Format the prompt using the model's chat template
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print("formatted_prompt: ", formatted_prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        # Streamer yields decoded text pieces as generate() produces tokens
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        # Run generate() in a background thread; this generator consumes the
        # streamer on the calling thread and yields partial text to the UI.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Consume the stream
        generated_text = ""
        start_time = time.time()
        # NOTE: counts streamer chunks, which approximates (but may not exactly
        # equal) the number of generated tokens.
        token_count = 0
        last_update_time = start_time
        current_stats = ""
        for new_text in streamer:
            generated_text += new_text
            token_count += 1
            # Refresh the live throughput footer at most every 0.2 seconds
            current_time = time.time()
            if current_time - last_update_time > 0.2:
                elapsed = current_time - start_time
                if elapsed > 0:
                    tps = token_count / elapsed
                    current_stats = f"\n\n---\n*Generating... ({tps:.1f} t/s)*"
                last_update_time = current_time
            yield generated_text + current_stats
        # Final stats
        elapsed_time = time.time() - start_time
        if elapsed_time > 0:
            tps = token_count / elapsed_time
            stats = f"\n\n---\n*Generated {token_count} tokens in {elapsed_time:.2f}s ({tps:.2f} t/s)*"
            yield generated_text + stats
    except Exception as e:
        yield f"❌ Error during generation: {str(e)}"
# Custom CSS to force full viewport height and style the chat column.
# Applied to the UI via gr.Blocks / launch below.
css = """
.gradio-container {
    height: 100vh !important;
    max-height: 100vh !important;
    overflow: hidden !important;
}
#main-row {
    height: calc(100vh - 150px) !important;
}
#chat-col {
    height: 100% !important;
}
/* Thin box around prompt field - targeting specifically within chat column */
#chat-col textarea {
    border: 1px solid #64748b !important;
    border-radius: 8px !important;
    padding: 8px !important;
}
"""
# Custom theme built on the Soft preset with a smaller base font and
# slightly darker primary-button hover state.
custom_theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
    text_size="md",
    spacing_size="sm",
    radius_size="md"
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_700",
    block_title_text_weight="600",
)
# Build the Gradio interface.
# FIX: `theme` and `css` must be passed to the gr.Blocks constructor —
# Blocks.launch() has no such parameters, so the custom theme/CSS defined
# above were never actually applied.
with gr.Blocks(fill_height=True, theme=custom_theme, css=css) as demo:
    gr.Markdown(
        """
# 🤖 Smol LLM Chat - Multi-turn chat with SmolLM3-3B.
"""
    )
    with gr.Row(elem_id="main-row"):
        # Left column: generation parameters and system prompt.
        with gr.Column(scale=1, min_width=200):
            with gr.Accordion("⚙️ Parameters", open=False):
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=1024,
                    value=200,
                    step=50,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.1,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty"
                )
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value="You are a helpful AI assistant. Provide clear and concise answers.",
                    lines=2
                )
        # Right column: the streaming chat interface itself.
        with gr.Column(scale=4, elem_id="chat-col"):
            chat_interface = gr.ChatInterface(
                fn=chat_predict,
                fill_height=True,
                additional_inputs=[
                    max_tokens,
                    temperature,
                    top_p,
                    repetition_penalty,
                    system_prompt
                ],
            )
# Auto-load the model at import time (runs even when this module is imported
# rather than executed, e.g. by the Spaces runtime). load_model() returns a
# status string instead of raising, so a failure is only logged here.
load_status = load_model()
print(f"Startup load status: {load_status}")
if __name__ == "__main__":
    # Launch the application.
    # FIX: removed theme=/css= — Blocks.launch() does not accept these
    # keyword arguments (they belong to the gr.Blocks constructor) and
    # passing them raises a TypeError on current Gradio versions.
    demo.launch(
        share=False,
        show_error=True
    )