# slm-testing / app.py
# stanley-00's picture
# Update app.py
# 473340d verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gc
import os
import shutil
import torch
import psutil
import time
# Hugging Face hub cache directory that the "Clean HF Cache" action wipes.
# Resolved in the lookup order documented for huggingface_hub: HF_HUB_CACHE
# wins, then $HF_HOME/hub, then $XDG_CACHE_HOME/huggingface/hub, and finally
# ~/.cache/huggingface/hub, so a relocated cache is cleaned as well.
HF_CACHE_DIR = os.path.expanduser(
    os.environ.get("HF_HUB_CACHE")
    or os.path.join(
        os.environ.get("HF_HOME")
        or os.path.join(os.environ.get("XDG_CACHE_HOME") or "~/.cache", "huggingface"),
        "hub",
    )
)
# Model ids offered as suggestions in the model dropdown (custom ids are also accepted)
MODELS = [
    "HuggingFaceTB/SmolLM2-135M",
    "AxiomicLabs/GPT-X2-125M",
    "Qwen/Qwen3-0.6B",
    "facebook/MobileLLM-R1-140M-base",
    "SupraLabs/Supra-50M-Base",
    "CompactAI-O/Shard-1",
    "SupraLabs/Supra-50M-Instruct",
    "HuggingFaceTB/SmolLM-135M",
    "facebook/opt-125m",
    "AxiomicLabs/GPT-S-5M",
    "openai-community/gpt2",
    "LH-Tech-AI/Spark-5M-Base-v4",
    "SupraLabs/Supra-Mini-v5-8M",
    "EleutherAI/pythia-70m",
    "SupraLabs/Supra-Mini-v4-2M",
    "EleutherAI/pythia-31m",
    "StentorLabs/Stentor3-50M",
    "StentorLabs/Stentor3-20M",
    "StentorLabs/Portimbria-150M",
    "HuggingFaceTB/nanowhale-100m-base",
    "EleutherAI/pythia-14m",
    "Harley-ml/Tenete-8M",
    "Harley-ml/Dillion-1.2M",
    "MihaiPopa-1/CinnabarLM-1.4M-Base",
    "MihaiPopa-1/CinnabarLM-4M-Base",
    "MihaiPopa-1/PotentSulfurLM-500K-Base",
    "MihaiPopa-1/CinnabarLM-1.5M-Base",
    "Harley-ml/Dillionv2-1.3M",
    "Eclipse-Senpai/KeyLM-75M",
    "SupraLabs/Supra-Mini-v6-1M",
    "AxiomicLabs/GPT-S-1.4M",
    "GODELEV/Archaea-74M",
    "Sandroeth/cali-0.1B",
    "veyra-ai/veyra3-5m-base",
    "veyra-ai/veyra-30m-base-5b-tokens",
    "ThingAI/Quark-50m",
    "ThingAI/Quark-135m",
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    "Aravindan/awesome-gpt-2-coder",
    "Qwen/Qwen2.5-Coder-0.5B",
    "SupraLabs/Supra-50M-Reasoning",
]
# Maps a Gradio session hash to the time.time() stamp of its latest heartbeat.
ACTIVE_SESSIONS = {}
# Seconds without a heartbeat after which a session stops being counted.
SESSION_TIMEOUT = 60


def live_count(request: gr.Request):
    """Record a heartbeat for the calling session and return the live session count.

    Args:
        request: The Gradio request of the current call. When falsy, no
            heartbeat is recorded and only pruning and counting happen.

    Returns:
        Number of sessions whose last heartbeat is at most ``SESSION_TIMEOUT``
        seconds old.
    """
    now = time.time()
    if request:
        ACTIVE_SESSIONS[request.session_hash] = now
    # Prune over a snapshot: Gradio may run handlers for different sessions in
    # parallel worker threads, and iterating the live dict while another call
    # inserts a key raises "dictionary changed size during iteration".
    for session, last_seen in list(ACTIVE_SESSIONS.items()):
        if now - last_seen > SESSION_TIMEOUT:
            ACTIVE_SESSIONS.pop(session, None)
    return len(ACTIVE_SESSIONS)
# Process-wide holder for the currently loaded model/tokenizer pair
class ModelManager:
    """Keeps the active model, its tokenizer and the target device together."""

    def __init__(self):
        # Nothing is loaded at start-up; both slots stay None until a load succeeds.
        self.model = self.tokenizer = None
        # Use the GPU whenever torch reports one, otherwise fall back to CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"


# Single shared instance used by the request handlers in this module.
model_manager = ModelManager()
def get_system_stats(request: gr.Request = None):
    """Build the multi-line status text shown in the system monitoring box.

    Args:
        request: Optional Gradio request. When given, the caller's session is
            registered through ``live_count``; when ``None`` the current size
            of ``ACTIVE_SESSIONS`` is reported without touching it.

    Returns:
        A single string with CPU, memory, disk and active-session lines.
    """
    gib = 1024**3
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    # interval=1 makes psutil sample for one second, so this call blocks that long.
    cpu = psutil.cpu_percent(interval=1)
    if request is None:
        sessions = len(ACTIVE_SESSIONS)
    else:
        sessions = live_count(request)
    rows = [
        f"CPU\t\t: \t{cpu}%",
        f"Mem\t\t: \t{round(mem.used / gib, 2)} / {round(mem.total / gib, 2)} GB",
        f"Disk\t\t: \t{round(disk.used / gib, 2)} / {round(disk.total / gib, 2)} GB",
        f"Active\t: \t{sessions} session(s)",
    ]
    return "\n".join(rows)
def load_new_model(model_id):
    """Load a model and tokenizer into the global manager, streaming status text.

    This is a generator so the UI can show progress while the load runs.

    Args:
        model_id: Model id or path passed to ``from_pretrained``. ``None``,
            empty and whitespace-only values are rejected.

    Yields:
        Status strings: a validation message, a "Loading ..." notice, and then
        either a success or an error message.
    """
    model_id = (model_id or "").strip()
    if not model_id:
        # Validate before unloading so an empty selection does not throw away
        # the model that is currently loaded.
        yield "Please select or enter a model id first."
        return

    # Drop the old references first so their memory can be reclaimed before
    # the new weights are loaded.
    model_manager.model = None
    model_manager.tokenizer = None
    yield f"Loading {model_id}..."
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    try:
        # Load explicitly for streaming purposes instead of pipeline.
        # SECURITY NOTE(review): trust_remote_code=True runs Python code shipped
        # in the model repository, and model_id comes straight from the UI.
        # Confirm this is acceptable for the deployment.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(model_manager.device)
        # Publish both only after both loaded, so the manager never holds a
        # tokenizer without its model.
        model_manager.tokenizer = tokenizer
        model_manager.model = model
        yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
    except Exception as e:
        # UI boundary: report any load failure as status text instead of raising.
        yield f"Error loading model: {str(e)}"
def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalty, ngram_size, do_sample):
    """Stream generated text for ``user_prompt`` from the currently loaded model.

    Generation runs in a background thread and its output is read through a
    ``TextIteratorStreamer``.

    Args:
        user_prompt: Prompt text; it is also the prefix of every yielded text.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (replaced by 1.0 when not sampling).
        top_k: Top-k sampling cutoff.
        top_p: Nucleus sampling threshold.
        rep_penalty: Repetition penalty.
        ngram_size: ``no_repeat_ngram_size`` passed to ``generate``.
        do_sample: Whether to sample instead of decoding greedily.

    Yields:
        ``(text, status)`` tuples: the prompt plus everything generated so far,
        and a status line (speed, timeout notice or error message).
    """
    from queue import Empty  # raised by the streamer when its timeout expires

    if model_manager.model is None or model_manager.tokenizer is None:
        yield "Please load a model first.", "Model not loaded"
        return
    # Keep local references so a concurrent reload cannot swap them mid-run.
    tokenizer = model_manager.tokenizer
    model = model_manager.model
    # Tokenize input
    inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)
    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    # Adjust variables based on the do_sample logic
    if not do_sample:
        temperature = 1.0  # Temperature is ignored if do_sample=False, but setting it > 0 avoids config errors
    # Generation arguments
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(rep_penalty),
        no_repeat_ngram_size=int(ngram_size),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,  # Prevents padding warnings
    )

    errors = []

    def _generate():
        # Runs in the worker thread. Without this guard an exception inside
        # generate() would be lost with the thread, and the reader below would
        # stay blocked until the streamer timeout expires.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # push the end-of-stream signal to unblock the reader

    start_time = time.time()
    # Start generation in a separate background thread
    thread = Thread(target=_generate)
    thread.start()
    # Yield output iteratively for the streaming effect
    generated_text = user_prompt
    token_count = 0
    try:
        for new_text in streamer:
            generated_text += new_text
            # NOTE(review): this counts streamer chunks, which may not map
            # one-to-one to tokens, so the reported speed is approximate.
            token_count += 1
            duration = time.time() - start_time
            tps = token_count / duration if duration > 0 else 0
            yield generated_text, f"Speed: {tps:.2f} tokens/sec"
    except Empty:
        # No text arrived within the streamer timeout; report it instead of crashing.
        yield generated_text, "Generation timed out waiting for the model."
        return
    # The stream has ended, so the worker is finishing; wait for it so any
    # recorded error is visible before reporting.
    thread.join()
    if errors:
        yield generated_text, f"Error during generation: {errors[0]}"
def clean_cache(cache_dir=None):
    """Delete a Hugging Face cache directory and recreate it empty.

    Args:
        cache_dir: Directory to wipe. Defaults to the module-level
            ``HF_CACHE_DIR`` when omitted.

    Returns:
        A status message saying whether the directory was cleaned, was not
        found, or could not be cleaned.
    """
    target = HF_CACHE_DIR if cache_dir is None else cache_dir
    # isdir (not exists): rmtree cannot remove a plain file at that path.
    if not os.path.isdir(target):
        return "Cache directory not found."
    try:
        shutil.rmtree(target)
        os.makedirs(target, exist_ok=True)
    except OSError as e:
        # Report filesystem failures (permissions, files in use, ...) as status
        # text, in line with how model-loading errors are reported.
        return f"Failed to clean cache: {e}"
    return "Cache cleaned successfully!"
# Gradio Interface
# NOTE(review): the nesting below is reconstructed from the component order;
# confirm in particular whether clean_btn belongs inside the model group.
with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")

    with gr.Row():
        # Left column: Settings & Monitoring
        with gr.Column(scale=1):
            with gr.Accordion("System Monitoring", open=True):
                stats_output = gr.Textbox(label="Live System Stats", show_label=False)
                # Refresh the stats box on a two-second timer.
                stats_timer = gr.Timer(2)
                stats_timer.tick(fn=get_system_stats, inputs=None, outputs=stats_output)

            with gr.Group():
                gr.Markdown("### Select or Paste custom model id here")
                with gr.Row():
                    model_id_input = gr.Dropdown(
                        choices=MODELS,
                        label="Model",
                        allow_custom_value=True,
                        show_label=False,
                        scale=3,
                    )
                    load_btn = gr.Button("Load", variant="secondary", scale=1)
                clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")

            with gr.Accordion("Generation Configuration", open=False):
                do_sample_input = gr.Checkbox(
                    label="Enable Sampling (do_sample)",
                    value=True,
                    info="Uncheck for greedy decoding",
                )
                max_tokens_input = gr.Slider(
                    minimum=10, maximum=2048, value=256, step=1,
                    label="Max Output Tokens",
                )
                temperature_input = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature", info="Higher = more creative",
                )
                top_k_input = gr.Slider(
                    minimum=0, maximum=100, value=50, step=1,
                    label="Top-K", info="0 = disabled",
                )
                top_p_input = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-P (Nucleus)", info="1.0 = disabled",
                )
                rep_penalty_input = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty", info="1.0 = disabled",
                )
                ngram_size_input = gr.Slider(
                    minimum=0, maximum=10, value=0, step=1,
                    label="No Repeat N-Gram Size", info="0 = disabled",
                )

        # Right column: Interaction
        with gr.Column(scale=2):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="Once upon a time in a digital kingdom,",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            run_btn = gr.Button("Generate text", variant="primary", size="lg")
            status_output = gr.Markdown("Status: *Waiting to load model...*")
            output_text = gr.Textbox(label="Result", lines=15, buttons=["copy"], autoscroll=True)

    # Events
    load_btn.click(fn=load_new_model, inputs=[model_id_input], outputs=[status_output])

    # run_inference is a generator, so each value it yields is streamed to the
    # outputs. The input order must match its parameter order.
    generation_inputs = [
        user_prompt,
        max_tokens_input,
        temperature_input,
        top_k_input,
        top_p_input,
        rep_penalty_input,
        ngram_size_input,
        do_sample_input,
    ]
    run_btn.click(fn=run_inference, inputs=generation_inputs, outputs=[output_text, status_output])

    clean_btn.click(fn=clean_cache, outputs=[status_output])

if __name__ == "__main__":
    app.launch()