File size: 8,936 Bytes
cc8c521
ee07f77
 
cc8c521
 
 
 
 
2962bbd
cc8c521
 
 
 
55c8db2
 
473340d
55c8db2
 
 
 
 
 
 
 
 
 
 
32f15bb
f7b504f
55c8db2
 
a1baed0
4692774
a1baed0
 
 
 
 
 
 
 
 
 
 
 
 
ee07f77
 
 
 
 
 
 
 
 
b685633
cc8c521
 
 
a1baed0
109ea8d
 
 
4692774
13c3fc6
cc8c521
 
ee07f77
cc8c521
ee07f77
 
32f15bb
cc8c521
 
 
 
 
ee07f77
 
 
 
 
 
 
32f15bb
cc8c521
32f15bb
cc8c521
ee07f77
 
 
2962bbd
ee07f77
cc8c521
ee07f77
 
 
 
 
 
 
134a4fc
ee07f77
 
 
 
 
 
 
 
 
41aa821
 
 
ee07f77
 
 
 
 
41aa821
2962bbd
 
ee07f77
 
 
 
 
 
2962bbd
ee07f77
 
2962bbd
 
 
 
cc8c521
 
 
 
 
 
 
 
 
ee07f77
cc8c521
ee07f77
 
cc8c521
ee07f77
3238ae1
ee07f77
41aa821
20ab2a9
ee07f77
41aa821
ee07f77
a0adfc8
ee07f77
 
 
 
 
9e4c82c
ee07f77
 
20ab2a9
ee07f77
 
 
 
 
 
9e4c82c
3238ae1
 
ee07f77
 
a0adfc8
ee07f77
 
 
0e2a4a3
a817aa1
0e2a4a3
cc8c521
 
 
 
 
ee07f77
cc8c521
 
ee07f77
cc8c521
 
ee07f77
 
 
 
 
 
 
 
 
 
a817aa1
cc8c521
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import gc
import os
import shutil
import torch
import psutil
import time

def resolve_hf_cache_dir(env=None):
    """Return the Hugging Face hub cache directory for the given environment.

    The lookup follows the precedence documented for ``huggingface_hub``:
    ``HF_HUB_CACHE`` (or the legacy ``HUGGINGFACE_HUB_CACHE``) wins, then
    ``$HF_HOME/hub``, then ``$XDG_CACHE_HOME/huggingface/hub``, and finally
    ``~/.cache/huggingface/hub``.

    Args:
        env: Mapping of environment variables to consult. Defaults to
            ``os.environ``; a plain dict can be passed for testing.

    Returns:
        The cache directory path with ``~`` and ``$VARS`` expanded.
    """
    if env is None:
        env = os.environ

    explicit = env.get("HF_HUB_CACHE") or env.get("HUGGINGFACE_HUB_CACHE")
    if explicit:
        return os.path.expandvars(os.path.expanduser(explicit))

    hf_home = env.get("HF_HOME")
    if not hf_home:
        xdg_cache = env.get("XDG_CACHE_HOME")
        if not xdg_cache:
            # Nothing configured: same location the app always used.
            return os.path.expanduser("~/.cache/huggingface/hub")
        hf_home = os.path.join(xdg_cache, "huggingface")

    return os.path.join(os.path.expandvars(os.path.expanduser(hf_home)), "hub")


# Directory wiped by the "Clean HF Cache" button. Resolved from the
# environment so it matches the cache the libraries actually write to.
HF_CACHE_DIR = resolve_hf_cache_dir()

# Model ids offered as choices in the model dropdown (custom ids can still be typed in)
MODELS = [
    "HuggingFaceTB/SmolLM2-135M",
    "AxiomicLabs/GPT-X2-125M",
    "Qwen/Qwen3-0.6B",
    "facebook/MobileLLM-R1-140M-base",
    "SupraLabs/Supra-50M-Base",
    "CompactAI-O/Shard-1",
    "SupraLabs/Supra-50M-Instruct",
    "HuggingFaceTB/SmolLM-135M",
    "facebook/opt-125m",
    "AxiomicLabs/GPT-S-5M",
    "openai-community/gpt2",
    "LH-Tech-AI/Spark-5M-Base-v4",
    "SupraLabs/Supra-Mini-v5-8M",
    "EleutherAI/pythia-70m",
    "SupraLabs/Supra-Mini-v4-2M",
    "EleutherAI/pythia-31m",
    "StentorLabs/Stentor3-50M",
    "StentorLabs/Stentor3-20M",
    "StentorLabs/Portimbria-150M",
    "HuggingFaceTB/nanowhale-100m-base",
    "EleutherAI/pythia-14m",
    "Harley-ml/Tenete-8M",
    "Harley-ml/Dillion-1.2M",
    "MihaiPopa-1/CinnabarLM-1.4M-Base",
    "MihaiPopa-1/CinnabarLM-4M-Base",
    "MihaiPopa-1/PotentSulfurLM-500K-Base",
    "MihaiPopa-1/CinnabarLM-1.5M-Base",
    "Harley-ml/Dillionv2-1.3M",
    "Eclipse-Senpai/KeyLM-75M",
    "SupraLabs/Supra-Mini-v6-1M",
    "AxiomicLabs/GPT-S-1.4M",
    "GODELEV/Archaea-74M",
    "Sandroeth/cali-0.1B",
    "veyra-ai/veyra3-5m-base",
    "veyra-ai/veyra-30m-base-5b-tokens",
    "ThingAI/Quark-50m",
    "ThingAI/Quark-135m",
    "HuggingFaceTB/SmolLM2-135M-Instruct",
    "Aravindan/awesome-gpt-2-coder",
    "Qwen/Qwen2.5-Coder-0.5B",
    "SupraLabs/Supra-50M-Reasoning",
]

# Maps a session hash to the time.time() value of its most recent heartbeat.
ACTIVE_SESSIONS = {}
# A session with no heartbeat for more than this many seconds is dropped.
SESSION_TIMEOUT = 60


def live_count(request: gr.Request):
    """Record a heartbeat for the caller and return how many sessions are live.

    Args:
        request: Request of the calling session. When truthy, its
            ``session_hash`` is stamped with the current time.

    Returns:
        Number of entries left in ``ACTIVE_SESSIONS`` after entries older
        than ``SESSION_TIMEOUT`` have been removed.
    """
    now = time.time()
    if request:
        ACTIVE_SESSIONS[request.session_hash] = now

    # Collect the stale keys first; the dict must not shrink while it is iterated.
    stale = []
    for session_id, last_seen in ACTIVE_SESSIONS.items():
        if now - last_seen > SESSION_TIMEOUT:
            stale.append(session_id)

    for session_id in stale:
        ACTIVE_SESSIONS.pop(session_id, None)

    return len(ACTIVE_SESSIONS)

# Single shared holder for whichever model/tokenizer pair is currently loaded
class ModelManager:
    """Keeps the active model, its tokenizer and the device they run on."""

    def __init__(self):
        # Both slots stay empty until load_new_model() fills them in.
        self.model = self.tokenizer = None
        # Decide once, at construction time, whether CUDA can be used.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"


model_manager = ModelManager()

def get_system_stats(request: gr.Request = None):
    """Build the multi-line text shown in the live system-stats box.

    Args:
        request: Request of the calling session. When given, the session is
            registered through ``live_count``; when ``None`` the current
            size of ``ACTIVE_SESSIONS`` is reported without touching it.

    Returns:
        One string with a line each for CPU load, memory use, disk use of
        ``/`` and the number of active sessions.
    """
    gib = 1024**3
    mem = psutil.virtual_memory()
    disk = psutil.disk_usage('/')

    # cpu_percent is asked to measure over a 1-second interval (see psutil docs).
    cpu_line = f"CPU\t\t: \t{psutil.cpu_percent(interval=1)}%"
    mem_line = f"Mem\t\t: \t{round(mem.used / gib, 2)} / {round(mem.total / gib, 2)} GB"
    disk_line = f"Disk\t\t: \t{round(disk.used / gib, 2)} / {round(disk.total / gib, 2)} GB"

    if request is None:
        active = len(ACTIVE_SESSIONS)
    else:
        active = live_count(request)
    active_line = f"Active\t: \t{active} session(s)"

    return "\n".join([cpu_line, mem_line, disk_line, active_line])

def load_new_model(model_id):
    """Load ``model_id`` into the global ``model_manager``, yielding status text.

    Written as a generator so the UI can show a "Loading..." message before
    the (potentially slow) download/load, followed by the final outcome.

    Args:
        model_id: Model id or path handed to ``from_pretrained``. Surrounding
            whitespace is stripped; an empty or missing value is rejected
            without unloading the model that is currently in memory.

    Yields:
        Status strings: a progress message, then either a success message
        or an error message containing the exception text.
    """
    model_id = str(model_id).strip() if model_id else ""
    if not model_id:
        # Reject before touching model_manager so an accidental click on
        # "Load" with nothing selected does not throw away a working model.
        yield "Error loading model: no model id provided."
        return

    # Drop the references to the old model first so its memory can be
    # reclaimed before the new weights are allocated.
    model_manager.model = None
    model_manager.tokenizer = None
    yield f"Loading {model_id}..."
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        # Load model and tokenizer explicitly (rather than via a pipeline)
        # because streaming generation needs direct access to both.
        # NOTE(security): trust_remote_code=True executes Python shipped in
        # the model repository. model_id can be typed in freely by the user,
        # so this should only be exposed to trusted users / sandboxed hosts.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(model_manager.device)

        # Publish only after both loads succeeded, so the manager never
        # holds a tokenizer without its model or vice versa.
        model_manager.tokenizer = tokenizer
        model_manager.model = model

        yield f"Successfully loaded {model_id} on {model_manager.device.upper()}"
    except Exception as e:
        # Top-level UI boundary: report any load failure as status text.
        yield f"Error loading model: {str(e)}"

def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalty, ngram_size, do_sample):
    """Stream text generated from ``user_prompt`` by the loaded model.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator drains the streamer and yields
    the growing text after every chunk.

    Args:
        user_prompt: Prompt text; it is also the prefix of every yielded text.
        max_tokens: Passed as ``max_new_tokens`` (cast to int).
        temperature: Sampling temperature; replaced by 1.0 when
            ``do_sample`` is false.
        top_k: Passed as ``top_k`` (cast to int).
        top_p: Passed as ``top_p`` (cast to float).
        rep_penalty: Passed as ``repetition_penalty`` (cast to float).
        ngram_size: Passed as ``no_repeat_ngram_size`` (cast to int).
        do_sample: Passed as ``do_sample``.

    Yields:
        ``(text, status)`` tuples. ``status`` is a speed readout while
        streaming, or an error message if generation failed in the worker
        thread. If no model is loaded a single hint tuple is yielded.
    """
    if model_manager.model is None or model_manager.tokenizer is None:
        yield "Please load a model first.", "Model not loaded"
        return

    # Keep local references so this run keeps using the same objects even
    # if model_manager is reassigned while streaming.
    tokenizer = model_manager.tokenizer
    model = model_manager.model

    # Tokenize input
    inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)

    # Set up the streamer
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # Adjust variables based on the do_sample logic
    if not do_sample:
        temperature = 1.0  # Temperature is ignored if do_sample=False, but setting it > 0 avoids config errors

    # Generation arguments
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_k=int(top_k),
        top_p=float(top_p),
        repetition_penalty=float(rep_penalty),
        no_repeat_ngram_size=int(ngram_size),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id  # Prevents padding warnings
    )

    failures = []

    def _generate():
        # Runs in the worker thread. An exception raised here would otherwise
        # vanish with the thread and leave the consumer blocked until the
        # streamer timeout, so record it and close the stream explicitly.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    start_time = time.time()
    # Start generation in a separate background thread
    thread = Thread(target=_generate)
    thread.start()

    # Yield output iteratively for the streaming effect
    generated_text = user_prompt
    token_count = 0
    for new_text in streamer:
        generated_text += new_text
        # Each item delivered by the streamer is counted as one token.
        token_count += 1
        duration = time.time() - start_time
        tps = token_count / duration if duration > 0 else 0
        yield generated_text, f"Speed: {tps:.2f} tokens/sec"

    # The stream has ended, so the worker is finished or about to finish.
    thread.join()
    if failures:
        yield generated_text, f"Error during generation: {failures[0]}"

def clean_cache(cache_dir=None):
    """Delete a cache directory's contents and recreate it empty.

    Args:
        cache_dir: Directory to wipe. Defaults to ``HF_CACHE_DIR``.

    Returns:
        A status message: success, "not found" when the path does not
        exist, or a failure message carrying the ``OSError`` text when the
        directory could not be removed or recreated.
    """
    target = HF_CACHE_DIR if cache_dir is None else cache_dir

    if not os.path.exists(target):
        return "Cache directory not found."

    try:
        shutil.rmtree(target)
        os.makedirs(target, exist_ok=True)
    except OSError as exc:
        # Permission problems, files in use, a symlinked cache, etc. are
        # reported as status text instead of surfacing as an unhandled error.
        return f"Failed to clean cache: {exc}"

    return "Cache cleaned successfully!"

# Gradio Interface
# Builds the whole UI inside one gr.Blocks context (bound to `app`) and then
# wires the three buttons to load_new_model, run_inference and clean_cache.
with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
    
    gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")

    with gr.Row():
        # Left column: Settings & Monitoring
        with gr.Column(scale=1):
            
            with gr.Accordion("System Monitoring", open=True):
                stats_output = gr.Textbox(label="Live System Stats", show_label=False)
                # Refresh the stats box from get_system_stats on every timer tick.
                # NOTE(review): the 2 is presumably a period in seconds -- confirm
                # against the gr.Timer docs. No inputs are passed explicitly.
                gr.Timer(2).tick(get_system_stats, None, stats_output)

            with gr.Group():
                gr.Markdown("### Select or Paste custom model id here")
                with gr.Row():
                    # allow_custom_value=True lets the user enter ids that are not in MODELS.
                    model_id_input = gr.Dropdown(choices=MODELS, label="Model", allow_custom_value=True, show_label=False, scale=3)
                    load_btn = gr.Button("Load", variant="secondary", scale=1)
                
                clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")

            # Each control below feeds the run_inference parameter of the
            # matching name (see the inputs list of run_btn.click).
            with gr.Accordion("Generation Configuration", open=False):
                do_sample_input = gr.Checkbox(label="Enable Sampling (do_sample)", value=True, info="Uncheck for greedy decoding")
                max_tokens_input = gr.Slider(minimum=10, maximum=2048, value=256, step=1, label="Max Output Tokens")
                temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Higher = more creative")
                
                top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="0 = disabled")
                top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus)", info="1.0 = disabled")
                rep_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="1.0 = disabled")
                ngram_size_input = gr.Slider(minimum=0, maximum=10, value=0, step=1, label="No Repeat N-Gram Size", info="0 = disabled")

        # Right column: Interaction
        with gr.Column(scale=2):
            user_prompt = gr.Textbox(
                label="Prompt", 
                value="Once upon a time in a digital kingdom,", 
                placeholder="Enter your prompt here...", 
                lines=5
            )
            run_btn = gr.Button("Generate text", variant="primary", size="lg")
            # Shared status line: written by the load, generate and clean handlers.
            status_output = gr.Markdown("Status: *Waiting to load model...*")
            output_text = gr.Textbox(label="Result", lines=15, buttons=["copy"], autoscroll=True)
    
    # Events
    # load_new_model is a generator; each string it yields replaces the status line.
    load_btn.click(
        fn=load_new_model, 
        inputs=[model_id_input], 
        outputs=[status_output]
    )
    
    # We use `.click` targeting a generator function, which Gradio naturally treats as a streaming output
    # The inputs are listed in the same order as run_inference's parameters, and
    # each yielded (text, status) pair maps onto [output_text, status_output].
    run_btn.click(
        fn=run_inference,
        inputs=[
            user_prompt, 
            max_tokens_input, 
            temperature_input, 
            top_k_input, 
            top_p_input, 
            rep_penalty_input, 
            ngram_size_input,
            do_sample_input
        ],
        outputs=[output_text, status_output]
    )
    
    # clean_cache takes no inputs; its returned message goes to the status line.
    clean_btn.click(fn=clean_cache, outputs=[status_output])

# Start the server only when the file is run as a script, not when imported.
if __name__ == "__main__":
    app.launch()