# SmolTransform / app.py
# Hugging Face Space: Gradio chat UI for SmolLM3-3B (TobDeBer)
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import time
from threading import Thread
import sys
import os
# CPU tuning knobs for the Space's CPU runtime.
# NOTE(review): these are set AFTER torch/transformers are imported, so
# OMP/MKL may have already read them — confirm they take effect, or move
# them above the heavy imports.
# os.environ["BNB_CUDA_VERSION"] = "0" # Forces bitsandbytes to recognize no GPU
os.environ["OMP_NUM_THREADS"] = "1" # Prevents race conditions in custom CPU kernels
os.environ["VECLIB_MAXIMUM_ISA"] = "AVX2"
os.environ["MKL_DEBUG_CPU_TYPE"] = "5" # Forces MKL to use AVX2
# Optional ZeroGPU support: on Hugging Face Spaces the `spaces` package
# supplies the @spaces.GPU decorator; anywhere else we install a no-op shim
# so the same decorated code runs unchanged.
try:
    import spaces
except ImportError:
    spaces = None

if spaces is None or not torch.cuda.is_available():
    print("Using CPU-only mode (spaces.GPU disabled)")

    class SpacesShim:
        """Stand-in for the `spaces` module whose GPU decorator does nothing."""

        def GPU(self, *args, **kwargs):
            # Bare usage: @spaces.GPU — the function itself arrives as the
            # single positional argument and is returned untouched.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]
            # Parenthesized usage: @spaces.GPU(duration=30) — return a
            # decorator that passes the function through unchanged.
            return lambda func: func

    spaces = SpacesShim()


def gpu_decorator(func):
    """Wrap *func* with spaces.GPU() (real on Spaces, identity elsewhere)."""
    return spaces.GPU()(func)
# Model configuration: a local checkpoint path may be supplied as the first
# CLI argument; otherwise fall back to the default Hub repository.
if len(sys.argv) > 1 and os.path.exists(sys.argv[1]):
    MODEL_NAME = sys.argv[1]
    print(f"Using local model from: {MODEL_NAME}")
else:
    # Other checkpoints tried during development:
    #   TobDeBer/SmolLM3-3B-hirma-b80s-0.5
    #   TobDeBer/SmolLM3-3B-hirma-b60s-0.5
    #   TobDeBer/SmolLM2-135M-Instruct-hirma-b60s-0.5
    #   TobDeBer/SmolLM2-135M-Instruct-b100
    #   TobDeBer/SmolLM3-3B-hirma-b60-bnb4
    #   TobDeBer/SmolLM3-3B-hirma-b60-0.5
    #   TobDeBer/SmolLM3-3B-hirma-q20-bnb8
    #   TobDeBer/SmolLM3-3B-hirma-q20
    #   TobDeBer/SmolLM3-3B-hirma-q80-bnb4
    #   TobDeBer/SmolLM2-135M-Instruct-q99-bnb4
    #   HuggingFaceTB/SmolLM2-135M-Instruct
    MODEL_NAME = "TobDeBer/SmolLM3-3B-hirma-b100-0.5"
# Global model handles, populated by load_model() at startup.
tokenizer = None
model = None

import platform
import subprocess

# Fix: cpuinfo (py-cpuinfo) is a third-party optional dependency; a bare
# import crashed the app when it was not installed. It is not used by any
# code in this file, so degrade gracefully to None.
try:
    import cpuinfo  # optional: `pip install py-cpuinfo`
except ImportError:
    cpuinfo = None
def load_model():
    """Load the tokenizer and model, selecting device/dtype from hardware.

    Prints a short hardware audit (CPU flags via `lscpu`, best effort) and
    loads MODEL_NAME onto GPU (bfloat16, device_map="auto") when CUDA is
    available, otherwise onto CPU in float32.

    Returns:
        A human-readable status string; errors are reported in the return
        value instead of being raised, so startup never crashes the UI.
    """
    global tokenizer, model
    try:
        print("--- Hardware Audit ---")
        print(f"Processor: {platform.processor()}")
        print(f"Machine: {platform.machine()}")
        # Check for CPU instruction-set flags — Linux only, best effort.
        try:
            # Fix: argument list with shell=False instead of a shell string.
            cpu_flags = subprocess.check_output(["lscpu"]).decode()
            flags_lower = cpu_flags.lower()  # hoisted out of the loop
            print("Instruction sets found:")
            for flag in ["avx512", "avx2", "avx", "fma", "amx"]:
                if flag in flags_lower:
                    print(f" ✅ {flag.upper()} supported")
                else:
                    print(f" ❌ {flag.upper()} NOT found")
        except Exception as e:
            print(f"Could not check CPU flags: {e}")
        print(f"PyTorch version: {torch.__version__}")
        print(f"Loading model: {MODEL_NAME}")
        print("----------------------")

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
        # Left-pad so generated text begins immediately after the prompt.
        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Pick device and dtype based on available hardware.
        if torch.cuda.is_available():
            print(" ✅ CUDA detected. Loading model on GPU.")
            device_map = "auto"
            dtype = torch.bfloat16
        else:
            print(" ⚠️ No CUDA detected. Loading model on CPU.")
            device_map = {"": "cpu"}
            dtype = torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            dtype=dtype,
            device_map=device_map,
            low_cpu_mem_usage=True
        )
        # Fix: removed an unconditional `model.to(torch.bfloat16)` that
        # overrode the float32 dtype deliberately chosen for the CPU path
        # (and was a no-op on the GPU path, which already loads bfloat16).
        return "✅ Model loaded successfully!"
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"
@spaces.GPU(duration=30)
def chat_predict(message, history, max_length, temperature, top_p, repetition_penalty, system_prompt):
    """Stream a chat completion for *message* given the Gradio *history*.

    Yields progressively longer strings (Gradio replaces the bubble text on
    each yield). A live throughput footer is appended while generating and
    replaced by a final stats footer when generation finishes.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        yield "⚠️ Please wait for the model to finish loading..."
        return
    try:
        # Build the message list for the chat template.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # History entries are role/content dicts; content may be a list of
        # parts (multimodal format in Gradio 6) — keep only the text parts.
        for msg in history:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if isinstance(content, list):
                text_content = ""
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        text_content += part.get("text", "")
                content = text_content
            # Ensure content is a plain string.
            if not isinstance(content, str):
                content = str(content)
            # Strip the stats footer this function appends to its own
            # replies so it is not fed back into the model as context.
            if role == "assistant" and "\n\n---\n*Generated" in content:
                content = content.split("\n\n---\n*Generated")[0]
            messages.append({"role": role, "content": content})
        messages.append({"role": "user", "content": message})
        # Render the conversation with the model's chat template.
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        print("formatted_prompt: ", formatted_prompt)
        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
        # The streamer yields decoded text pieces as generate() emits tokens.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Generation arguments
        generation_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        # generate() blocks, so run it on a worker thread and consume the
        # streamer from this (generator) thread.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        # Consume the stream.
        generated_text = ""
        start_time = time.time()
        # NOTE(review): the streamer yields decoded chunks which may not map
        # 1:1 to tokens, so the t/s figure is approximate.
        token_count = 0
        last_update_time = start_time
        current_stats = ""
        for new_text in streamer:
            generated_text += new_text
            token_count += 1
            # Refresh the live stats footer at most every 0.2 seconds.
            current_time = time.time()
            if current_time - last_update_time > 0.2:
                elapsed = current_time - start_time
                if elapsed > 0:
                    tps = token_count / elapsed
                    current_stats = f"\n\n---\n*Generating... ({tps:.1f} t/s)*"
                last_update_time = current_time
            yield generated_text + current_stats
        # Final stats footer (skipped if no measurable time elapsed).
        elapsed_time = time.time() - start_time
        if elapsed_time > 0:
            tps = token_count / elapsed_time
            stats = f"\n\n---\n*Generated {token_count} tokens in {elapsed_time:.2f}s ({tps:.2f} t/s)*"
            yield generated_text + stats
    except Exception as e:
        yield f"❌ Error during generation: {str(e)}"
# Custom CSS to force full height and style chat
css = """
.gradio-container {
height: 100vh !important;
max-height: 100vh !important;
overflow: hidden !important;
}
#main-row {
height: calc(100vh - 150px) !important;
}
#chat-col {
height: 100% !important;
}
/* Thin box around prompt field - targeting specifically within chat column */
#chat-col textarea {
border: 1px solid #64748b !important;
border-radius: 8px !important;
padding: 8px !important;
}
"""
# Create custom theme with smaller base font
custom_theme = gr.themes.Soft(
primary_hue="blue",
secondary_hue="indigo",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter"),
text_size="md",
spacing_size="sm",
radius_size="md"
).set(
button_primary_background_fill="*primary_600",
button_primary_background_fill_hover="*primary_700",
block_title_text_weight="600",
)
# Build the Gradio interface.
# Fix: `theme` and `css` are gr.Blocks constructor arguments; previously they
# were only passed to demo.launch(), which does not accept them, so the custom
# theme and CSS were never applied.
with gr.Blocks(theme=custom_theme, css=css, fill_height=True) as demo:
    gr.Markdown(
        """
        # 🤖 Smol LLM Chat - Multi-turn chat with SmolLM3-3B.
        """
    )
    with gr.Row(elem_id="main-row"):
        # Left column: generation parameters, collapsed by default.
        with gr.Column(scale=1, min_width=200):
            with gr.Accordion("⚙️ Parameters", open=False):
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=1024,
                    value=200,
                    step=50,
                    label="Max Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.1,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty"
                )
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    value="You are a helpful AI assistant. Provide clear and concise answers.",
                    lines=2
                )
        # Right column: the chat interface, fed by chat_predict.
        with gr.Column(scale=4, elem_id="chat-col"):
            chat_interface = gr.ChatInterface(
                fn=chat_predict,
                fill_height=True,
                additional_inputs=[
                    max_tokens,
                    temperature,
                    top_p,
                    repetition_penalty,
                    system_prompt
                ],
            )
# Auto-load the model at import time so the Space is ready when the UI starts.
load_status = load_model()
print(f"Startup load status: {load_status}")

if __name__ == "__main__":
    # Fix: removed theme=/css= from launch() — Blocks.launch() does not
    # accept those parameters (they belong on the gr.Blocks constructor)
    # and passing them raises TypeError on current Gradio.
    demo.launch(
        share=False,
        show_error=True
    )