Spaces:
Running
Running
File size: 8,936 Bytes
cc8c521 ee07f77 cc8c521 2962bbd cc8c521 55c8db2 473340d 55c8db2 32f15bb f7b504f 55c8db2 a1baed0 4692774 a1baed0 ee07f77 b685633 cc8c521 a1baed0 109ea8d 4692774 13c3fc6 cc8c521 ee07f77 cc8c521 ee07f77 32f15bb cc8c521 ee07f77 32f15bb cc8c521 32f15bb cc8c521 ee07f77 2962bbd ee07f77 cc8c521 ee07f77 134a4fc ee07f77 41aa821 ee07f77 41aa821 2962bbd ee07f77 2962bbd ee07f77 2962bbd cc8c521 ee07f77 cc8c521 ee07f77 cc8c521 ee07f77 3238ae1 ee07f77 41aa821 20ab2a9 ee07f77 41aa821 ee07f77 a0adfc8 ee07f77 9e4c82c ee07f77 20ab2a9 ee07f77 9e4c82c 3238ae1 ee07f77 a0adfc8 ee07f77 0e2a4a3 a817aa1 0e2a4a3 cc8c521 ee07f77 cc8c521 ee07f77 cc8c521 ee07f77 a817aa1 cc8c521 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 | import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gc
import os
import shutil
import torch
import psutil
import time
def resolve_hf_cache_dir(environ=None):
    """Return the Hugging Face hub cache directory, honouring env overrides.

    Lookup order: ``HF_HUB_CACHE``, then the legacy ``HUGGINGFACE_HUB_CACHE``,
    then ``$HF_HOME/hub``, then ``$XDG_CACHE_HOME/huggingface/hub``, and
    finally ``~/.cache/huggingface/hub``. Empty variables count as unset.

    Args:
        environ: Mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        str: Absolute (user-expanded) path of the hub cache directory.
    """
    env = os.environ if environ is None else environ

    def _expand(path):
        return os.path.expandvars(os.path.expanduser(path))

    for var in ("HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"):
        if env.get(var):
            return _expand(env[var])
    if env.get("HF_HOME"):
        return os.path.join(_expand(env["HF_HOME"]), "hub")
    if env.get("XDG_CACHE_HOME"):
        return os.path.join(_expand(env["XDG_CACHE_HOME"]), "huggingface", "hub")
    return os.path.expanduser("~/.cache/huggingface/hub")


# Define path for HF cache to clean. Resolved from the environment so it points
# at the directory models are actually downloaded to when HF_HOME & co. are set.
HF_CACHE_DIR = resolve_hf_cache_dir()

# List of models for autocomplete
MODELS = [
    'HuggingFaceTB/SmolLM2-135M', 'AxiomicLabs/GPT-X2-125M', 'Qwen/Qwen3-0.6B',
    'facebook/MobileLLM-R1-140M-base', 'SupraLabs/Supra-50M-Base', 'CompactAI-O/Shard-1',
    'SupraLabs/Supra-50M-Instruct', 'HuggingFaceTB/SmolLM-135M', 'facebook/opt-125m',
    'AxiomicLabs/GPT-S-5M', 'openai-community/gpt2', 'LH-Tech-AI/Spark-5M-Base-v4',
    'SupraLabs/Supra-Mini-v5-8M', 'EleutherAI/pythia-70m', 'SupraLabs/Supra-Mini-v4-2M',
    'EleutherAI/pythia-31m', 'StentorLabs/Stentor3-50M', 'StentorLabs/Stentor3-20M',
    'StentorLabs/Portimbria-150M', 'HuggingFaceTB/nanowhale-100m-base', 'EleutherAI/pythia-14m',
    'Harley-ml/Tenete-8M', 'Harley-ml/Dillion-1.2M', 'MihaiPopa-1/CinnabarLM-1.4M-Base',
    'MihaiPopa-1/CinnabarLM-4M-Base', 'MihaiPopa-1/PotentSulfurLM-500K-Base',
    'MihaiPopa-1/CinnabarLM-1.5M-Base', 'Harley-ml/Dillionv2-1.3M', 'Eclipse-Senpai/KeyLM-75M',
    'SupraLabs/Supra-Mini-v6-1M', 'AxiomicLabs/GPT-S-1.4M', 'GODELEV/Archaea-74M',
    'Sandroeth/cali-0.1B', 'veyra-ai/veyra3-5m-base', 'veyra-ai/veyra-30m-base-5b-tokens',
    'ThingAI/Quark-50m', 'ThingAI/Quark-135m', 'HuggingFaceTB/SmolLM2-135M-Instruct',
    'Aravindan/awesome-gpt-2-coder', 'Qwen/Qwen2.5-Coder-0.5B', 'SupraLabs/Supra-50M-Reasoning',
]
# Session registry: request.session_hash -> time.time() of that session's last call.
ACTIVE_SESSIONS = {}
# Seconds without a call before a session is no longer counted as active.
SESSION_TIMEOUT = 60


def live_count(request: gr.Request):
    """Refresh the caller's timestamp, evict stale sessions, return the count.

    Args:
        request: Gradio request of the calling session. When falsy, nothing is
            recorded and only the eviction pass runs.

    Returns:
        int: Number of sessions seen within the last SESSION_TIMEOUT seconds.
    """
    now = time.time()
    if request:
        ACTIVE_SESSIONS[request.session_hash] = now

    # Collect first, then evict, so the dict is not resized while iterating it.
    stale_keys = [
        key
        for key, last_seen in ACTIVE_SESSIONS.items()
        if now - last_seen > SESSION_TIMEOUT
    ]
    for key in stale_keys:
        ACTIVE_SESSIONS.pop(key, None)

    return len(ACTIVE_SESSIONS)
# Global holder for the loaded model and tokenizer (plain attributes, no locking).
class ModelManager:
    """Keeps the currently loaded model, its tokenizer and the target device."""

    def __init__(self):
        # Nothing is loaded until load_new_model() assigns these.
        self.model = None
        self.tokenizer = None
        # Run on the GPU whenever PyTorch can see one.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"


# Single shared instance read and written by the Gradio event handlers.
model_manager = ModelManager()
def get_system_stats(request: gr.Request = None):
    """Return a multi-line, tab-aligned text summary of host metrics.

    Reports CPU utilisation, used/total RAM, used/total space on the ``/``
    filesystem (bytes / 1024**3, rounded to 2 decimals) and the number of
    active sessions.

    Args:
        request: Request of the calling session (normally supplied by Gradio
            via the ``gr.Request`` annotation). When given, the session is
            refreshed through ``live_count``, which also prunes stale
            sessions. When ``None``, the current size of ``ACTIVE_SESSIONS``
            is reported as-is.

    Returns:
        str: Four ``label : value`` lines for the stats textbox.
    """
    gib = 1024 ** 3
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    # Non-blocking sample: CPU use since the previous call, instead of sleeping
    # a full second on every timer tick. psutil returns a meaningless 0.0 on
    # the very first call of the process.
    cpu = psutil.cpu_percent(interval=None)
    sessions = len(ACTIVE_SESSIONS) if request is None else live_count(request)
    return (
        f"CPU\t\t: \t{cpu}%\n"
        f"Mem\t\t: \t{round(mem.used / gib, 2)} / {round(mem.total / gib, 2)} GB\n"
        f"Disk\t\t: \t{round(disk.used / gib, 2)} / {round(disk.total / gib, 2)} GB\n"
        f"Active\t: \t{sessions} session(s)"
    )
def load_new_model(model_id):
    """Load *model_id* (tokenizer + causal LM) into the global ``model_manager``.

    Implemented as a generator so Gradio streams each status message to the UI.

    Args:
        model_id: Hub model id from the dropdown; may be ``None`` or padded
            with whitespace when pasted as a custom value.

    Yields:
        str: Status messages (prompt to pick a model, "Loading ...", success,
        or the load error).
    """
    model_id = (model_id or "").strip()
    if not model_id:
        # Nothing to load; keep whatever model is currently in memory.
        yield "Please select or enter a model id first."
        return
    # Drop references to the previous model so its memory can be reclaimed
    # before the new weights are allocated.
    model_manager.model = None
    model_manager.tokenizer = None
    yield f"Loading {model_id}..."
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    try:
        # Load explicitly (not via pipeline) so generation can be streamed.
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repo, and model_id can be any id a visitor types (the
        # dropdown allows custom values). Consider restricting remote code to
        # vetted ids such as those in MODELS.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(model_manager.device)
    except Exception as e:  # UI boundary: report any load failure to the user
        yield f"Error loading model: {str(e)}"
        return
    model_manager.tokenizer = tokenizer
    model_manager.model = model
    yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalty, ngram_size, do_sample):
    """Stream a completion of *user_prompt* from the currently loaded model.

    ``model.generate`` runs on a background thread that feeds a
    ``TextIteratorStreamer``. This generator turns the streamed text into
    ``(text, status)`` updates for the result textbox and the status line.

    Args:
        user_prompt: Prompt text; it is echoed at the start of the output.
        max_tokens: Maximum number of new tokens (``max_new_tokens``).
        temperature: Sampling temperature (forced to 1.0 when not sampling).
        top_k: Top-K filter passed to ``generate``.
        top_p: Nucleus (top-p) threshold passed to ``generate``.
        rep_penalty: ``repetition_penalty`` passed to ``generate``.
        ngram_size: ``no_repeat_ngram_size`` passed to ``generate``.
        do_sample: Sample when truthy, greedy decoding otherwise.

    Yields:
        tuple[str, str]: Prompt plus the text generated so far, and a status
        line with the generation speed (or the generation error).
    """
    if model_manager.model is None or model_manager.tokenizer is None:
        yield "Please load a model first.", "Model not loaded"
        return
    # Local references so a concurrent reload cannot swap them mid-run.
    tokenizer = model_manager.tokenizer
    model = model_manager.model
    inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)
    prompt_length = inputs["input_ids"].shape[-1]
    # skip_prompt: only newly generated text is streamed back.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    if not do_sample:
        # Greedy decoding ignores temperature; pin it to the neutral 1.0.
        temperature = 1.0
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(rep_penalty),
        no_repeat_ngram_size=int(ngram_size),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,  # Prevents padding warnings
    )
    outcome = {}

    def _generate():
        # Runs on the worker thread. If generate() raised without this guard,
        # nothing would end the streamer and the loop below would block until
        # the 60 s streamer timeout; end() pushes the stop signal instead.
        try:
            outcome["output"] = model.generate(**generate_kwargs)
        except Exception as exc:
            outcome["error"] = exc
            streamer.end()

    start_time = time.time()

    def _speed(tokens):
        # Throughput since start_time (includes prompt processing).
        duration = time.time() - start_time
        tps = tokens / duration if duration > 0 else 0
        return f"Speed: {tps:.2f} tokens/sec"

    thread = Thread(target=_generate)
    thread.start()
    generated_text = user_prompt
    token_count = 0
    for new_text in streamer:
        generated_text += new_text
        # Streamer chunks are decoded text (often whole words), not single
        # tokens, so re-tokenise each chunk to estimate the running count.
        token_count += len(tokenizer.encode(new_text, add_special_tokens=False))
        yield generated_text, _speed(token_count)
    thread.join()
    if "error" in outcome:
        yield generated_text, f"Generation failed: {outcome['error']}"
        return
    output = outcome.get("output")
    if output is not None:
        # Exact final count: a causal LM returns prompt tokens followed by the
        # new ones (a dict-style output keeps them under .sequences).
        sequences = getattr(output, "sequences", output)
        token_count = sequences.shape[-1] - prompt_length
    # Always emit a final update, even when nothing was streamed, so the UI
    # does not keep showing a previous result.
    yield generated_text, _speed(token_count)
def clean_cache(cache_dir=None):
    """Delete and recreate the Hugging Face hub cache directory.

    Args:
        cache_dir: Directory to wipe; ``None`` (the default, used by the UI
            button) means the module-level ``HF_CACHE_DIR``.

    Returns:
        str: Status message for the UI: success, "not found", or the OS error
        that prevented cleaning.
    """
    target = HF_CACHE_DIR if cache_dir is None else cache_dir
    if not os.path.exists(target):
        return "Cache directory not found."
    try:
        shutil.rmtree(target)
        # exist_ok: another process may recreate the directory in between.
        os.makedirs(target, exist_ok=True)
    except OSError as exc:
        # e.g. permission problems, or the path is a file rather than a directory
        return f"Failed to clean cache: {exc}"
    return "Cache cleaned successfully!"
# Gradio Interface
with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")
    with gr.Row():
        # Left column: Settings & Monitoring
        with gr.Column(scale=1):
            with gr.Accordion("System Monitoring", open=True):
                stats_output = gr.Textbox(label="Live System Stats", show_label=False)
                # Poll every 2 s; get_system_stats also refreshes this
                # session's entry in the live-session counter.
                gr.Timer(2).tick(get_system_stats, None, stats_output)
            with gr.Group():
                gr.Markdown("### Select or Paste custom model id here")
                with gr.Row():
                    # allow_custom_value lets users paste any Hub model id.
                    model_id_input = gr.Dropdown(choices=MODELS, label="Model", allow_custom_value=True, show_label=False, scale=3)
                    load_btn = gr.Button("Load", variant="secondary", scale=1)
                clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")
            with gr.Accordion("Generation Configuration", open=False):
                do_sample_input = gr.Checkbox(label="Enable Sampling (do_sample)", value=True, info="Uncheck for greedy decoding")
                max_tokens_input = gr.Slider(minimum=10, maximum=2048, value=256, step=1, label="Max Output Tokens")
                temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Higher = more creative")
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="0 = disabled")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus)", info="1.0 = disabled")
                rep_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="1.0 = disabled")
                ngram_size_input = gr.Slider(minimum=0, maximum=10, value=0, step=1, label="No Repeat N-Gram Size", info="0 = disabled")
        # Right column: Interaction
        with gr.Column(scale=2):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="Once upon a time in a digital kingdom,",
                placeholder="Enter your prompt here...",
                lines=5
            )
            run_btn = gr.Button("Generate text", variant="primary", size="lg")
            status_output = gr.Markdown("Status: *Waiting to load model...*")
            output_text = gr.Textbox(label="Result", lines=15, buttons=["copy"], autoscroll=True)

    # Events
    # Loading, generating and cache cleaning all act on the one shared
    # model_manager / HF cache, so they share a single concurrency group.
    # With Gradio's default concurrency limit (1) they run one at a time, so a
    # model cannot be swapped out, or the cache wiped, while another event is
    # using it.
    load_btn.click(
        fn=load_new_model,
        inputs=[model_id_input],
        outputs=[status_output],
        concurrency_id="model",
    )
    # run_inference is a generator function, which Gradio treats as a streaming output
    run_btn.click(
        fn=run_inference,
        inputs=[
            user_prompt,
            max_tokens_input,
            temperature_input,
            top_k_input,
            top_p_input,
            rep_penalty_input,
            ngram_size_input,
            do_sample_input,
        ],
        outputs=[output_text, status_output],
        concurrency_id="model",
    )
    clean_btn.click(fn=clean_cache, outputs=[status_output], concurrency_id="model")


if __name__ == "__main__":
    app.launch()