# slm-testing / app.py (author: treerats88)
# Last commit 2563114 (verified): "Add model switching interruption and support for reasoning model tokens"
# File size: 11.5 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
import gc
import os
import shutil
import torch
import psutil
import time
def resolve_hf_cache_dir(environ=None):
    """Return the Hugging Face hub cache directory, resolved like ``huggingface_hub``.

    Precedence: ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``,
    then ``<HF_HOME>/hub``. ``HF_HOME`` itself defaults to
    ``<XDG_CACHE_HOME or ~/.cache>/huggingface``. With none of these variables
    set, the result is the expanded ``~/.cache/huggingface/hub``.

    Args:
        environ: Mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        The cache directory path with ``~`` and ``$VARS`` expanded.
    """
    env = os.environ if environ is None else environ
    default_home = os.path.join(os.path.expanduser("~"), ".cache")
    hf_home = os.path.expandvars(
        os.path.expanduser(
            env.get("HF_HOME", os.path.join(env.get("XDG_CACHE_HOME", default_home), "huggingface"))
        )
    )
    hub_cache = env.get("HF_HUB_CACHE", env.get("HUGGINGFACE_HUB_CACHE", os.path.join(hf_home, "hub")))
    return os.path.expandvars(os.path.expanduser(hub_cache))


# Hub cache directory wiped by the "Clean HF Cache" button. It honours the same
# environment variables the hub client uses, so the button cleans the real cache.
HF_CACHE_DIR = resolve_hf_cache_dir()
# Suggested Hub repo ids shown in the model dropdown, in display order.
# Users may also type any other id (the dropdown allows custom values).
MODELS = [
    "HuggingFaceTB/SmolLM2-135M",
    "AxiomicLabs/GPT-X2-125M",
    "Qwen/Qwen3-0.6B",
    "facebook/MobileLLM-R1-140M-base",
    "SupraLabs/Supra-50M-Base",
    "CompactAI-O/Shard-1",
    "SupraLabs/Supra-50M-Instruct",
    "HuggingFaceTB/SmolLM-135M",
    "facebook/opt-125m",
    "AxiomicLabs/GPT-S-5M",
    "openai-community/gpt2",
    "LH-Tech-AI/Spark-5M-Base-v4",
    "SupraLabs/Supra-Mini-v5-8M",
    "EleutherAI/pythia-70m",
    "SupraLabs/Supra-Mini-v4-2M",
    "EleutherAI/pythia-31m",
    "StentorLabs/Stentor3-50M",
    "StentorLabs/Stentor3-20M",
    "StentorLabs/Portimbria-150M",
    "HuggingFaceTB/nanowhale-100m-base",
    "EleutherAI/pythia-14m",
    "Harley-ml/Tenete-8M",
    "Harley-ml/Dillion-1.2M",
    "MihaiPopa-1/CinnabarLM-1.4M-Base",
    "MihaiPopa-1/CinnabarLM-4M-Base",
    "MihaiPopa-1/PotentSulfurLM-500K-Base",
    "MihaiPopa-1/CinnabarLM-1.5M-Base",
    "Harley-ml/Dillionv2-1.3M",
    "Eclipse-Senpai/KeyLM-75M",
    "SupraLabs/Supra-Mini-v6-1M",
    "AxiomicLabs/GPT-S-1.4M",
    "GODELEV/Archaea-74M",
    "Sandroeth/cali-0.1B",
    "veyra-ai/veyra3-5m-base",
    "veyra-ai/veyra-30m-base-5b-tokens",
    "ThingAI/Quark-50m",
    "ThingAI/Quark-135m",
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    "Aravindan/awesome-gpt-2-coder",
    "Qwen/Qwen2.5-Coder-0.5B",
    "SupraLabs/Supra-50M-Reasoning",
]
# Last heartbeat time (time.time()) per Gradio session hash.
ACTIVE_SESSIONS = {}
# Seconds without a heartbeat after which a session is no longer counted.
SESSION_TIMEOUT = 60


def live_count(request: "gr.Request | None"):
    """Record a heartbeat for the calling session and return the live-session count.

    Sessions whose last heartbeat is older than ``SESSION_TIMEOUT`` seconds are
    pruned from ``ACTIVE_SESSIONS`` before counting.

    Args:
        request: The calling session's Gradio request. ``None`` only prunes and
            counts, without recording a heartbeat.

    Returns:
        The number of sessions still in ``ACTIVE_SESSIONS`` after pruning.
    """
    now = time.time()
    if request is not None:
        ACTIVE_SESSIONS[request.session_hash] = now
    # Iterate over a snapshot: Gradio runs sync handlers in worker threads, and
    # walking the live dict while another session inserts or prunes can raise
    # "RuntimeError: dictionary changed size during iteration".
    expired = [
        session
        for session, last_seen in list(ACTIVE_SESSIONS.items())
        if now - last_seen > SESSION_TIMEOUT
    ]
    for session in expired:
        # pop() with a default tolerates another thread having removed it already.
        ACTIVE_SESSIONS.pop(session, None)
    return len(ACTIVE_SESSIONS)
# Global holder for the currently loaded model and tokenizer (one shared instance, no locking)
class ModelManager:
    """Container for the single loaded model/tokenizer pair and its control flag."""

    def __init__(self):
        # Empty until load_new_model() assigns a freshly loaded model.
        self.model, self.tokenizer, self.model_id = None, None, None
        # Set to True to ask a running generate() call (via StopOnFlag) to halt.
        self.stop_generation = False
        # Use the GPU whenever PyTorch reports one as available.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"


# The one shared instance read and written by the load and inference handlers.
model_manager = ModelManager()
# Custom stopping criteria to halt the generation thread when loading a new model
class StopOnFlag(StoppingCriteria):
    """Stopping criterion that ends generation once ``model_manager.stop_generation`` is set.

    ``load_new_model`` raises that flag before discarding the current model,
    so a background ``generate()`` call stops when it next checks its criteria.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # transformers documents StoppingCriteria.__call__ as returning a BoolTensor
        # with one flag per batch row, which is what its built-in criteria (e.g.
        # MaxLengthCriteria) return. A bare Python bool only works by broadcasting.
        return torch.full(
            (input_ids.shape[0],),
            bool(model_manager.stop_generation),
            dtype=torch.bool,
            device=input_ids.device,
        )
def get_system_stats(request: gr.Request = None):
    """Return a tab-aligned, multi-line text report of current system metrics.

    The report covers CPU utilisation, used/total RAM and used/total disk space
    of the root filesystem (both in GB, rounded to 2 decimals), and the number of
    active sessions.

    Args:
        request: The calling session's request. Gradio fills it in because of the
            ``gr.Request`` annotation. When present it is passed to
            ``live_count``, which records this session's heartbeat. When ``None``,
            the current size of ``ACTIVE_SESSIONS`` is reported as-is.

    Returns:
        A ``str`` (not a dict) ready to display in a Textbox.
    """
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    return (
        # NOTE: interval=1 makes psutil sample for one second, so each call
        # blocks its worker thread for ~1 s.
        f"CPU\t\t: \t{psutil.cpu_percent(interval=1)}%\n"
        f"Mem\t\t: \t{round(mem.used / (1024**3), 2)} / {round(mem.total / (1024**3), 2)} GB\n"
        f"Disk\t\t: \t{round(disk.used / (1024**3), 2)} / {round(disk.total / (1024**3), 2)} GB\n"
        f"Active\t: \t{len(ACTIVE_SESSIONS) if request is None else live_count(request)} session(s)"
    )
def load_new_model(model_id):
    """Load a causal LM and its tokenizer into the global ``model_manager``.

    This is a generator used as a Gradio streaming handler. It yields status
    strings for the UI: a "Loading ..." message followed by either a success or
    an error message.

    Before loading, any in-flight generation is told to stop (through
    ``model_manager.stop_generation``, which ``StopOnFlag`` polls). The previous
    model is also dropped and memory is collected. If the new load fails, no
    model is left loaded.

    Args:
        model_id: Hub repo id (or local path) chosen or typed in the dropdown.
            Surrounding whitespace is ignored. An empty or ``None`` value keeps
            the currently loaded model untouched.
    """
    model_id = (model_id or "").strip()
    if not model_id:
        # Nothing selected: keep the current model rather than unloading it
        # only for from_pretrained() to fail on an empty id.
        yield "Please select or enter a model id first."
        return

    # Stop any ongoing generation immediately.
    model_manager.stop_generation = True
    # Drop references to the old model so its memory can be reclaimed.
    model_manager.model = None
    model_manager.tokenizer = None
    model_manager.model_id = None
    yield f"Loading {model_id}..."
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    try:
        # SECURITY NOTE(review): trust_remote_code=True runs Python code shipped
        # with the repo, and model_id may be any custom value typed into the UI.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(model_manager.device)
    except Exception as e:  # surface any hub/config/weights failure in the UI
        yield f"Error loading model: {e}"
        return
    # Publish only after both pieces loaded, so handlers never see a half-loaded pair.
    model_manager.tokenizer = tokenizer
    model_manager.model = model
    model_manager.model_id = model_id
    yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
def _build_prompt(user_prompt, is_supra_reasoning):
    """Return ``(prompt_to_encode, skip_special_tokens)`` for the active model.

    For Supra-50M-Reasoning the prompt is wrapped in a ``[SYSTEM]``/``[USER]``/
    ``[ASSISTANT]`` template ending with ``<|begin_of_thought|>``. Special tokens
    are kept in the decoded stream so the thought/solution markers can be
    rendered. Any other model gets the raw prompt with special tokens skipped.
    """
    if not is_supra_reasoning:
        return user_prompt, True
    system_prompt = "Your role as an assistant involves thoroughly exploring questions through a systematic long thinking process before providing the final precise and accurate solutions."
    prompt_to_encode = (
        f"[SYSTEM]: {system_prompt}\n\n"
        f"[USER]: {user_prompt}\n\n"
        f"[ASSISTANT]: <|begin_of_thought|>\n"
    )
    return prompt_to_encode, False


def _format_reasoning_output(text):
    """Replace the reasoning model's raw marker tokens with readable section headers."""
    text = text.replace("<s>", "").replace("</s>", "")
    # Substitute markers before checking for the header, so that a re-emitted
    # <|begin_of_thought|> at the start is not shown under a duplicated header.
    text = text.replace("<|begin_of_thought|>", "🧠 Thinking Process:\n")
    text = text.replace("<|end_of_thought|>", "\n\n")
    text = text.replace("<|begin_of_solution|>", "✅ Final Answer:\n\n")
    text = text.replace("<|end_of_solution|>", "")
    if not text.startswith("🧠 Thinking Process:"):
        # The opening marker lives in the prompt (skipped by the streamer), so
        # the header normally has to be added here.
        text = "🧠 Thinking Process:\n" + text
    return text


def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalty, ngram_size, do_sample):
    """Stream text from the currently loaded model.

    This is a generator used as a Gradio streaming handler. Every yield is a
    ``(display_text, status_text)`` pair for the result box and the status line.

    ``model.generate`` runs in a background thread that feeds a
    ``TextIteratorStreamer``. The UI loop stops early when
    ``model_manager.stop_generation`` is set (a new model is being loaded).
    Exceptions raised inside ``generate()`` are reported in the status line
    instead of leaving the stream waiting for its 60 s timeout. After a normal
    finish, a final update reports the exact new-token rate.

    Args:
        user_prompt: Prompt text from the UI.
        max_tokens: Maximum number of new tokens to generate.
        temperature, top_k, top_p: Sampling parameters. They are forwarded only
            when ``do_sample`` is true, because greedy decoding ignores them.
        rep_penalty: ``repetition_penalty`` (1.0 disables it).
        ngram_size: ``no_repeat_ngram_size`` (0 disables it).
        do_sample: Use sampling instead of greedy decoding.
    """
    if model_manager.model is None or model_manager.tokenizer is None:
        yield "Please load a model first.", "Model not loaded"
        return
    # Reset the stop flag for this run (load_new_model sets it to interrupt us).
    model_manager.stop_generation = False
    tokenizer = model_manager.tokenizer
    model = model_manager.model
    model_id = model_manager.model_id
    is_supra_reasoning = "Supra-50M-Reasoning" in model_id if model_id else False
    prompt_to_encode, skip_special = _build_prompt(user_prompt, is_supra_reasoning)

    inputs = tokenizer([prompt_to_encode], return_tensors="pt").to(model_manager.device)
    input_length = inputs["input_ids"].shape[-1]
    if input_length == 0:
        # generate() cannot start from an empty sequence (e.g. an empty prompt
        # with a tokenizer that adds no BOS token).
        yield "Please enter a prompt.", "Empty prompt"
        return

    # skip_prompt=True: only newly generated text is streamed back.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=skip_special)
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        repetition_penalty=float(rep_penalty),
        no_repeat_ngram_size=int(ngram_size),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,  # Prevents padding warnings
        stopping_criteria=StoppingCriteriaList([StopOnFlag()]),  # Lets load_new_model interrupt generation
    )
    if do_sample:
        # Greedy decoding ignores these knobs, so forward them only when sampling.
        generate_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    outcome = {}

    def _generate():
        # Record generate()'s result or exception for the consumer loop below.
        # On failure the streamer must be closed explicitly; otherwise iteration
        # blocks until the streamer timeout and then raises an opaque queue.Empty.
        try:
            outcome["sequences"] = model.generate(**generate_kwargs)
        except Exception as exc:  # re-surfaced to the UI after the stream ends
            outcome["error"] = exc
            streamer.end()

    start_time = time.time()
    thread = Thread(target=_generate)
    thread.start()

    if is_supra_reasoning:
        # Use plain text formatting rather than markdown symbols inside gr.Textbox
        base_display = f"Prompt: {user_prompt}\n\n----------------------------------------\n\n"
        generated_text = ""
    else:
        # Plain completion: show the prompt followed by its continuation.
        base_display = ""
        generated_text = user_prompt
    display_text = base_display + (
        _format_reasoning_output(generated_text) if is_supra_reasoning else generated_text
    )

    chunk_count = 0
    for new_text in streamer:
        # A new model is being loaded: stop updating the UI. StopOnFlag halts the
        # background generate() when it next checks its stopping criteria.
        if model_manager.stop_generation:
            return
        generated_text += new_text
        chunk_count += 1
        elapsed = time.time() - start_time
        # Live estimate only: the streamer emits decoded text chunks (often whole
        # words), so this counts chunks; the final update below is exact.
        tps = chunk_count / elapsed if elapsed > 0 else 0
        display_text = base_display + (
            _format_reasoning_output(generated_text) if is_supra_reasoning else generated_text
        )
        yield display_text, f"Speed: {tps:.2f} tokens/sec"

    thread.join()
    if model_manager.stop_generation:
        return
    if "error" in outcome:
        yield display_text, f"Generation error: {outcome['error']}"
        return

    elapsed = time.time() - start_time
    sequences = outcome.get("sequences")
    # generate() may return a ModelOutput (return_dict_in_generate) or a tensor.
    sequences = getattr(sequences, "sequences", sequences)
    new_tokens = chunk_count
    if sequences is not None and hasattr(sequences, "shape"):
        # Decoder-only outputs contain the prompt followed by the new tokens.
        new_tokens = max(int(sequences.shape[-1]) - input_length, 0)
    tps = new_tokens / elapsed if elapsed > 0 else 0
    yield display_text, f"Speed: {tps:.2f} tokens/sec"
def clean_cache(cache_dir=None):
    """Delete and recreate the Hugging Face hub cache directory.

    Args:
        cache_dir: Directory to wipe. Defaults to ``HF_CACHE_DIR``. The UI button
            calls this with no arguments.

    Returns:
        A status message for the UI. Filesystem errors (e.g. permissions, or the
        path not being a directory) are reported in the message instead of being
        raised into the Gradio handler.
    """
    if cache_dir is None:
        cache_dir = HF_CACHE_DIR
    if not os.path.exists(cache_dir):
        return "Cache directory not found."
    try:
        shutil.rmtree(cache_dir)
        # Recreate the empty directory so later downloads have somewhere to go.
        os.makedirs(cache_dir, exist_ok=True)
    except OSError as exc:
        return f"Failed to clean cache: {exc}"
    return "Cache cleaned successfully!"
# Gradio Interface
# Layout: a left column with settings (system stats, model picker, generation
# sliders) and a right column with the prompt, status line and streamed result.
# Components render in the order they are created inside their `with` context.
with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")
    with gr.Row():
        # Left column: Settings & Monitoring
        with gr.Column(scale=1):
            with gr.Accordion("System Monitoring", open=True):
                stats_output = gr.Textbox(label="Live System Stats", show_label=False)
                # Refresh the stats every 2 seconds. get_system_stats also records
                # this session's heartbeat (via live_count) when Gradio passes
                # the request.
                gr.Timer(2).tick(get_system_stats, None, stats_output)
            with gr.Group():
                gr.Markdown("### Select or Paste custom model id here")
                with gr.Row():
                    # allow_custom_value lets users type any repo id, not only MODELS.
                    model_id_input = gr.Dropdown(choices=MODELS, label="Model", allow_custom_value=True, show_label=False, scale=3)
                    load_btn = gr.Button("Load", variant="secondary", scale=1)
                clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")
            with gr.Accordion("Generation Configuration", open=False):
                # These controls map one-to-one onto run_inference's parameters.
                do_sample_input = gr.Checkbox(label="Enable Sampling (do_sample)", value=True, info="Uncheck for greedy decoding")
                max_tokens_input = gr.Slider(minimum=10, maximum=2048, value=256, step=1, label="Max Output Tokens")
                temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Higher = more creative")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="0 = disabled")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus)", info="1.0 = disabled")
                rep_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="1.0 = disabled")
                ngram_size_input = gr.Slider(minimum=0, maximum=10, value=0, step=1, label="No Repeat N-Gram Size", info="0 = disabled")
        # Right column: Interaction
        with gr.Column(scale=2):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="Once upon a time in a digital kingdom,",
                placeholder="Enter your prompt here...",
                lines=5
            )
            run_btn = gr.Button("Generate text", variant="primary", size="lg")
            # Shared status line: written by the load, generate and clean handlers.
            status_output = gr.Markdown("Status: *Waiting to load model...*")
            # NOTE(review): `buttons=` is the newer Gradio Textbox API (older
            # releases used show_copy_button) — confirm the pinned Gradio version.
            output_text = gr.Textbox(label="Result", lines=15, buttons=["copy"], autoscroll=True)
    # Events
    # load_new_model is a generator, so its status messages stream into status_output.
    load_btn.click(
        fn=load_new_model,
        inputs=[model_id_input],
        outputs=[status_output]
    )
    # We use `.click` targeting a generator function, which Gradio naturally treats as a streaming output
    run_btn.click(
        fn=run_inference,
        inputs=[
            user_prompt,
            max_tokens_input,
            temperature_input,
            top_k_input,
            top_p_input,
            rep_penalty_input,
            ngram_size_input,
            do_sample_input
        ],
        outputs=[output_text, status_output]
    )
    clean_btn.click(fn=clean_cache, outputs=[status_output])

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    app.launch()