Spaces:
Running
Running
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
| import gc | |
| import os | |
| import shutil | |
| import torch | |
| import psutil | |
| import time | |
# Hugging Face Hub download cache that the "Clean HF Cache" button wipes.
# Resolved in the same order huggingface_hub uses, so relocating the cache via
# environment variables does not make the cleaner target the wrong directory:
#   1. HF_HUB_CACHE
#   2. HUGGINGFACE_HUB_CACHE (legacy name)
#   3. $HF_HOME/hub, where HF_HOME defaults to $XDG_CACHE_HOME/huggingface,
#      and XDG_CACHE_HOME defaults to ~/.cache
_HF_HOME = os.path.expanduser(
    os.getenv(
        "HF_HOME",
        os.path.join(
            os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")),
            "huggingface",
        ),
    )
)
HF_CACHE_DIR = os.path.expanduser(
    os.getenv("HF_HUB_CACHE", os.getenv("HUGGINGFACE_HUB_CACHE", os.path.join(_HF_HOME, "hub")))
)
# Hub model ids pre-populated in the model dropdown, in display order.
MODELS = [
    "HuggingFaceTB/SmolLM2-135M",
    "AxiomicLabs/GPT-X2-125M",
    "Qwen/Qwen3-0.6B",
    "facebook/MobileLLM-R1-140M-base",
    "SupraLabs/Supra-50M-Base",
    "CompactAI-O/Shard-1",
    "SupraLabs/Supra-50M-Instruct",
    "HuggingFaceTB/SmolLM-135M",
    "facebook/opt-125m",
    "AxiomicLabs/GPT-S-5M",
    "openai-community/gpt2",
    "LH-Tech-AI/Spark-5M-Base-v4",
    "SupraLabs/Supra-Mini-v5-8M",
    "EleutherAI/pythia-70m",
    "SupraLabs/Supra-Mini-v4-2M",
    "EleutherAI/pythia-31m",
    "StentorLabs/Stentor3-50M",
    "StentorLabs/Stentor3-20M",
    "StentorLabs/Portimbria-150M",
    "HuggingFaceTB/nanowhale-100m-base",
    "EleutherAI/pythia-14m",
    "Harley-ml/Tenete-8M",
    "Harley-ml/Dillion-1.2M",
    "MihaiPopa-1/CinnabarLM-1.4M-Base",
    "MihaiPopa-1/CinnabarLM-4M-Base",
    "MihaiPopa-1/PotentSulfurLM-500K-Base",
    "MihaiPopa-1/CinnabarLM-1.5M-Base",
    "Harley-ml/Dillionv2-1.3M",
    "Eclipse-Senpai/KeyLM-75M",
    "SupraLabs/Supra-Mini-v6-1M",
    "AxiomicLabs/GPT-S-1.4M",
    "GODELEV/Archaea-74M",
    "Sandroeth/cali-0.1B",
    "veyra-ai/veyra3-5m-base",
    "veyra-ai/veyra-30m-base-5b-tokens",
    "ThingAI/Quark-50m",
    "ThingAI/Quark-135m",
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    "Aravindan/awesome-gpt-2-coder",
    "Qwen/Qwen2.5-Coder-0.5B",
    "SupraLabs/Supra-50M-Reasoning",
]
# session_hash -> time.time() of that session's most recent live_count() call.
ACTIVE_SESSIONS = dict()
# Seconds without a refresh after which a session is pruned from ACTIVE_SESSIONS.
SESSION_TIMEOUT = 60
def live_count(request: gr.Request):
    """Mark the caller's session as active now and return the number of live sessions.

    A session stays live while its last timestamp in ACTIVE_SESSIONS is no older
    than SESSION_TIMEOUT seconds; older entries are removed on every call.

    Args:
        request: Gradio request of the calling session. When falsy, nothing is
            recorded and only pruning happens.

    Returns:
        int: number of entries left in ACTIVE_SESSIONS after pruning.
    """
    now = time.time()
    if request:
        ACTIVE_SESSIONS[request.session_hash] = now
    # Scan a snapshot rather than the live view: Gradio may run handlers for
    # other sessions on worker threads that insert into ACTIVE_SESSIONS, and
    # iterating the live dict can then raise
    # "RuntimeError: dictionary changed size during iteration".
    expired = [sid for sid, seen in list(ACTIVE_SESSIONS.items()) if now - seen > SESSION_TIMEOUT]
    for sid in expired:
        ACTIVE_SESSIONS.pop(sid, None)
    return len(ACTIVE_SESSIONS)
class ModelManager:
    """Process-wide holder for the currently loaded causal LM and its tokenizer.

    Attributes:
        device: "cuda" when torch reports a CUDA device, otherwise "cpu".
        tokenizer: the loaded tokenizer, or None until a model is loaded.
        model: the loaded model, or None until a model is loaded.
    """

    def __init__(self):
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.tokenizer = None
        self.model = None


# Single shared instance used by every session's load and inference handlers.
model_manager = ModelManager()
def get_system_stats(request: gr.Request = None):
    """Return a tab-aligned, multi-line text summary of host and session stats.

    Polled periodically by the UI. When Gradio supplies ``request``, the
    caller's session is refreshed via ``live_count`` before counting; otherwise
    the current size of ACTIVE_SESSIONS is reported unchanged.

    Returns:
        str: four "Label : value" lines with CPU %, memory used/total and
        root-filesystem used/total (in units of 1024**3 bytes, labelled "GB"),
        and the active session count.
    """
    gib = 1024 ** 3
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')
    # Non-blocking sample: utilisation since the previous call. psutil returns
    # a meaningless 0.0 on the very first call. A blocking interval=1 sample
    # would hold a worker for a full second on every tick from every session.
    cpu = psutil.cpu_percent(interval=None)
    sessions = len(ACTIVE_SESSIONS) if request is None else live_count(request)
    return (
        f"CPU\t\t: \t{cpu}%\n"
        f"Mem\t\t: \t{round(mem.used / gib, 2)} / {round(mem.total / gib, 2)} GB\n"
        f"Disk\t\t: \t{round(disk.used / gib, 2)} / {round(disk.total / gib, 2)} GB\n"
        f"Active\t: \t{sessions} session(s)"
    )
def load_new_model(model_id):
    """Load ``model_id`` into the shared ``model_manager``, yielding status text.

    Generator used as a UI event handler: every yielded string is a progress or
    result message.

    Args:
        model_id: Hugging Face Hub repo id, chosen from the list or typed in.
            Surrounding whitespace is ignored; an empty/None id is rejected
            without unloading the current model.

    Yields:
        str: "Loading ..." followed by either a success message or
        "Error loading model: <exception>".
    """
    model_id = (model_id or "").strip()
    if not model_id:
        # Bail out before unloading so an accidental click keeps the current model.
        yield "Please select or enter a model id first."
        return

    # Drop the manager's references so the old weights can be reclaimed before
    # the new ones are allocated.
    model_manager.model = None
    model_manager.tokenizer = None
    yield f"Loading {model_id}..."
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        # SECURITY NOTE(review): trust_remote_code=True executes Python code shipped
        # in the model repo, and model_id may be any id the user typed. Consider
        # restricting custom ids or disabling remote code for untrusted repos.
        # Load model and tokenizer explicitly (not via pipeline) so generation can stream.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(model_manager.device)
    except Exception as e:  # UI boundary: surface any load failure as a status message
        yield f"Error loading model: {e}"
        return

    model_manager.tokenizer = tokenizer
    model_manager.model = model
    yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalty, ngram_size, do_sample):
    """Stream a continuation of ``user_prompt`` from the currently loaded model.

    Generator used as a UI event handler. ``model.generate`` runs in a
    background thread that feeds a TextIteratorStreamer; this function drains
    the streamer and yields after every decoded text chunk.

    Args:
        user_prompt: Text to continue.
        max_tokens: ``max_new_tokens`` for generation.
        temperature, top_k, top_p: Sampling controls, passed to ``generate``
            only when ``do_sample`` is true.
        rep_penalty: ``repetition_penalty`` (1.0 disables it).
        ngram_size: ``no_repeat_ngram_size`` (0 disables it).
        do_sample: Sample when true, greedy decoding otherwise.

    Yields:
        tuple[str, str]: (prompt + text generated so far, status message). The
        status is a tokens/sec figure, "Model not loaded", or a generation error.
    """
    # Read the shared manager once: a concurrent load_new_model() may clear or
    # replace these attributes between a check and a later read.
    tokenizer = model_manager.tokenizer
    model = model_manager.model
    if model is None or tokenizer is None:
        yield "Please load a model first.", "Model not loaded"
        return

    class _CountingStreamer(TextIteratorStreamer):
        """TextIteratorStreamer that also counts generated (non-prompt) tokens.

        The text chunks it yields are decoded word pieces, not tokens, so the
        chunk count cannot serve as a token count.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.generated_tokens = 0
            self._prompt_seen = False

        def put(self, value):
            if self._prompt_seen:
                self.generated_tokens += value.numel()
            else:
                # transformers' generate() passes the prompt ids in its first put()
                # call (the same ordering skip_prompt relies on).
                self._prompt_seen = True
            super().put(value)

    inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)
    streamer = _CountingStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        repetition_penalty=float(rep_penalty),
        no_repeat_ngram_size=int(ngram_size),
        do_sample=bool(do_sample),
        pad_token_id=tokenizer.eos_token_id,  # Prevents padding warnings
    )
    if do_sample:
        # Sampling knobs have no effect on greedy decoding, so only pass them when sampling.
        generate_kwargs.update(
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )

    errors = []

    def _generate():
        # An exception in a Thread target is otherwise only printed to stderr, and
        # the consumer loop below would block until the streamer's 60 s timeout.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # push the stop signal so the loop below terminates

    start_time = time.perf_counter()
    thread = Thread(target=_generate)
    thread.start()

    generated_text = user_prompt
    for new_text in streamer:
        generated_text += new_text
        elapsed = time.perf_counter() - start_time
        tps = streamer.generated_tokens / elapsed if elapsed > 0 else 0
        yield generated_text, f"Speed: {tps:.2f} tokens/sec"

    thread.join()
    if errors:
        yield generated_text, f"Generation error: {errors[0]}"
def clean_cache():
    """Delete the Hugging Face Hub cache directory (HF_CACHE_DIR) and recreate it empty.

    Returns:
        str: status message for the UI. Filesystem errors (e.g. permission
        denied) are reported in the message instead of being raised into the
        event handler.
    """
    if not os.path.exists(HF_CACHE_DIR):
        return "Cache directory not found."
    try:
        shutil.rmtree(HF_CACHE_DIR)
        os.makedirs(HF_CACHE_DIR, exist_ok=True)
    except OSError as exc:
        return f"Failed to clean cache: {exc}"
    return "Cache cleaned successfully!"
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")

    with gr.Row():
        # ---- Left column: monitoring, model selection, generation settings ----
        with gr.Column(scale=1):
            with gr.Accordion("System Monitoring", open=True):
                stats_output = gr.Textbox(label="Live System Stats", show_label=False)
                # Refresh the stats box every 2 seconds.
                stats_timer = gr.Timer(2)
                stats_timer.tick(fn=get_system_stats, inputs=None, outputs=stats_output)

            with gr.Group():
                gr.Markdown("### Select or Paste custom model id here")
                with gr.Row():
                    model_id_input = gr.Dropdown(
                        choices=MODELS,
                        label="Model",
                        allow_custom_value=True,
                        show_label=False,
                        scale=3,
                    )
                    load_btn = gr.Button("Load", variant="secondary", scale=1)
                # NOTE(review): container of this button is reconstructed; confirm
                # it belongs in the model group rather than the outer column.
                clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")

            with gr.Accordion("Generation Configuration", open=False):
                do_sample_input = gr.Checkbox(
                    label="Enable Sampling (do_sample)",
                    value=True,
                    info="Uncheck for greedy decoding",
                )
                max_tokens_input = gr.Slider(
                    minimum=10, maximum=2048, value=256, step=1,
                    label="Max Output Tokens",
                )
                temperature_input = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                    label="Temperature", info="Higher = more creative",
                )
                top_k_input = gr.Slider(
                    minimum=0, maximum=100, value=50, step=1,
                    label="Top-K", info="0 = disabled",
                )
                top_p_input = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                    label="Top-P (Nucleus)", info="1.0 = disabled",
                )
                rep_penalty_input = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty", info="1.0 = disabled",
                )
                ngram_size_input = gr.Slider(
                    minimum=0, maximum=10, value=0, step=1,
                    label="No Repeat N-Gram Size", info="0 = disabled",
                )

        # ---- Right column: prompt, run button, status and streamed result ----
        with gr.Column(scale=2):
            user_prompt = gr.Textbox(
                label="Prompt",
                value="Once upon a time in a digital kingdom,",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            run_btn = gr.Button("Generate text", variant="primary", size="lg")
            status_output = gr.Markdown("Status: *Waiting to load model...*")
            output_text = gr.Textbox(label="Result", lines=15, buttons=["copy"], autoscroll=True)

    # ---- Event wiring ----
    load_btn.click(fn=load_new_model, inputs=[model_id_input], outputs=[status_output])

    # Order must match run_inference's positional parameters.
    generation_inputs = [
        user_prompt,
        max_tokens_input,
        temperature_input,
        top_k_input,
        top_p_input,
        rep_penalty_input,
        ngram_size_input,
        do_sample_input,
    ]
    # run_inference is a generator, so each yielded (text, status) pair streams to the UI.
    run_btn.click(fn=run_inference, inputs=generation_inputs, outputs=[output_text, status_output])

    clean_btn.click(fn=clean_cache, outputs=[status_output])

if __name__ == "__main__":
    app.launch()