stanley-00 commited on
Commit
ee07f77
·
verified ·
1 Parent(s): 9e4c82c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -35
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
  import gc
4
  import os
5
  import shutil
@@ -26,6 +27,15 @@ MODELS = [
26
  'ThingAI/Quark-50m', 'ThingAI/Quark-135m'
27
  ]
28
 
 
 
 
 
 
 
 
 
 
29
  def get_system_stats():
30
  """Returns a dictionary of current system metrics with formatted strings."""
31
  mem = psutil.virtual_memory()
@@ -37,31 +47,68 @@ def get_system_stats():
37
  }
38
 
39
  def load_new_model(model_id):
 
40
  # Clear old model from memory
 
 
41
  gc.collect()
42
  if torch.cuda.is_available():
43
  torch.cuda.empty_cache()
44
 
45
  try:
46
- # Load a text-generation pipeline with trust_remote_code enabled
47
- pipe = pipeline("text-generation", model=model_id, trust_remote_code=True)
48
- return pipe, f"Successfully loaded {model_id}"
 
 
 
 
 
49
  except Exception as e:
50
- return None, f"Error loading model: {str(e)}"
51
 
52
- def run_inference(model, user_prompt, max_tokens, temperature, top_k):
53
- if not model:
54
- return "Please load a model first."
 
 
55
 
56
- # Run inference with additional sampling parameters
57
- result = model(
58
- user_prompt,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  max_new_tokens=int(max_tokens),
60
  temperature=float(temperature),
61
  top_k=int(top_k),
62
- do_sample=True
 
 
 
 
63
  )
64
- return result[0]['generated_text']
 
 
 
 
 
 
 
 
 
65
 
66
  def clean_cache():
67
  if os.path.exists(HF_CACHE_DIR):
@@ -71,47 +118,69 @@ def clean_cache():
71
  return "Cache directory not found."
72
 
73
  # Gradio Interface
74
- with gr.Blocks(title="Small MF Model Tester") as app:
75
-
76
- current_model = gr.State(None)
77
 
 
 
78
  with gr.Row():
79
- # Left column: Settings
80
  with gr.Column(scale=1):
81
- # Stats Section
82
  with gr.Accordion("System Monitoring", open=True):
83
  stats_output = gr.JSON(label="Live System Stats")
84
- gr.Timer(5).tick(get_system_stats, None, stats_output)
85
 
 
 
 
 
 
 
 
 
86
 
87
- model_id_input = gr.Dropdown(choices=MODELS, label="Model", allow_custom_value=True, show_label=False)
 
 
 
 
 
 
 
 
 
88
 
89
- max_tokens_input = gr.Slider(minimum=10, maximum=1024, value=128, step=1, label="Max Output Tokens")
90
- temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
91
- top_k_input = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-K Sampling")
92
-
93
- load_btn = gr.Button("Load", variant="secondary")
94
- clean_btn = gr.Button("Clean", variant="stop")
95
- status_output = gr.Markdown("Status: Waiting to load model...")
96
-
97
-
98
-
99
  # Right column: Interaction
100
  with gr.Column(scale=2):
101
- user_prompt = gr.Textbox(label="Prompt", value="Once upon a time,", placeholder="Enter your prompt here...", lines=5)
102
- run_btn = gr.Button("Run Inference", variant="primary")
103
- output_text = gr.Textbox(label="Result", lines=10)
 
 
 
 
 
104
 
105
  # Events
106
  load_btn.click(
107
  fn=load_new_model,
108
  inputs=[model_id_input],
109
- outputs=[current_model, status_output]
110
  )
111
 
 
112
  run_btn.click(
113
  fn=run_inference,
114
- inputs=[current_model, user_prompt, max_tokens_input, temperature_input, top_k_input],
 
 
 
 
 
 
 
 
 
115
  outputs=[output_text]
116
  )
117
 
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
+ from threading import Thread
4
  import gc
5
  import os
6
  import shutil
 
27
  'ThingAI/Quark-50m', 'ThingAI/Quark-135m'
28
  ]
29
 
30
+ # Global class to safely manage the loaded model and tokenizer in memory
31
+ class ModelManager:
32
+ def __init__(self):
33
+ self.model = None
34
+ self.tokenizer = None
35
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+ model_manager = ModelManager()
38
+
39
  def get_system_stats():
40
  """Returns a dictionary of current system metrics with formatted strings."""
41
  mem = psutil.virtual_memory()
 
47
  }
48
 
49
  def load_new_model(model_id):
50
+ """Loads the model and tokenizer dynamically into the global manager."""
51
  # Clear old model from memory
52
+ model_manager.model = None
53
+ model_manager.tokenizer = None
54
  gc.collect()
55
  if torch.cuda.is_available():
56
  torch.cuda.empty_cache()
57
 
58
  try:
59
+ # Load explicitly for streaming purposes instead of pipeline
60
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
61
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(model_manager.device)
62
+
63
+ model_manager.tokenizer = tokenizer
64
+ model_manager.model = model
65
+
66
+ return f"Successfully loaded {model_id} on {model_manager.device.upper()}"
67
  except Exception as e:
68
+ return f"Error loading model: {str(e)}"
69
 
70
+ def run_inference(user_prompt, max_tokens, temperature, top_k, top_p, rep_penalty, ngram_size, do_sample):
71
+ """Generates text via streaming generator."""
72
+ if model_manager.model is None or model_manager.tokenizer is None:
73
+ yield "Please load a model first."
74
+ return
75
 
76
+ tokenizer = model_manager.tokenizer
77
+ model = model_manager.model
78
+
79
+ # Tokenize input
80
+ inputs = tokenizer([user_prompt], return_tensors="pt").to(model_manager.device)
81
+
82
+ # Set up the streamer
83
+ streamer = TextIteratorStreamer(tokenizer, timeout=15.0, skip_prompt=True, skip_special_tokens=True)
84
+
85
+ # Adjust variables based on the do_sample logic
86
+ if not do_sample:
87
+ temperature = 1.0 # Temperature is ignored if do_sample=False, but setting it > 0 avoids config errors
88
+
89
+ # Generation arguments
90
+ generate_kwargs = dict(
91
+ **inputs,
92
+ streamer=streamer,
93
  max_new_tokens=int(max_tokens),
94
  temperature=float(temperature),
95
  top_k=int(top_k),
96
+ top_p=float(top_p),
97
+ repetition_penalty=float(rep_penalty),
98
+ no_repeat_ngram_size=int(ngram_size),
99
+ do_sample=do_sample,
100
+ pad_token_id=tokenizer.eos_token_id # Prevents padding warnings
101
  )
102
+
103
+ # Start generation in a separate background thread
104
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
105
+ thread.start()
106
+
107
+ # Yield output iteratively for the streaming effect
108
+ generated_text = user_prompt
109
+ for new_text in streamer:
110
+ generated_text += new_text
111
+ yield generated_text
112
 
113
  def clean_cache():
114
  if os.path.exists(HF_CACHE_DIR):
 
118
  return "Cache directory not found."
119
 
120
  # Gradio Interface
121
+ with gr.Blocks(title="Small MF Model Tester", theme=gr.themes.Soft()) as app:
 
 
122
 
123
+ gr.Markdown("# 🚀 Small Model Evaluation Hub with Streaming")
124
+
125
  with gr.Row():
126
+ # Left column: Settings & Monitoring
127
  with gr.Column(scale=1):
128
+
129
  with gr.Accordion("System Monitoring", open=True):
130
  stats_output = gr.JSON(label="Live System Stats")
131
+ gr.Timer(2).tick(get_system_stats, None, stats_output)
132
 
133
+ with gr.Group():
134
+ gr.Markdown("### Model Loader")
135
+ with gr.Row():
136
+ model_id_input = gr.Dropdown(choices=MODELS, label="Model", allow_custom_value=True, show_label=False, scale=3)
137
+ load_btn = gr.Button("Load", variant="secondary", scale=1)
138
+
139
+ status_output = gr.Markdown("Status: *Waiting to load model...*")
140
+ clean_btn = gr.Button("Clean HF Cache", variant="stop", size="sm")
141
 
142
+ with gr.Accordion("Generation Configuration", open=False):
143
+ do_sample_input = gr.Checkbox(label="Enable Sampling (do_sample)", value=True, info="Uncheck for greedy decoding")
144
+ max_tokens_input = gr.Slider(minimum=10, maximum=2048, value=128, step=1, label="Max Output Tokens")
145
+ temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature", info="Higher = more creative")
146
+
147
+ gr.Markdown("#### Advanced Sampling Constraints")
148
+ top_k_input = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="0 = disabled")
149
+ top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus)", info="1.0 = disabled")
150
+ rep_penalty_input = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="1.0 = disabled")
151
+ ngram_size_input = gr.Slider(minimum=0, maximum=10, value=0, step=1, label="No Repeat N-Gram Size", info="0 = disabled")
152
 
 
 
 
 
 
 
 
 
 
 
153
  # Right column: Interaction
154
  with gr.Column(scale=2):
155
+ user_prompt = gr.Textbox(
156
+ label="Prompt",
157
+ value="Once upon a time in a digital kingdom,",
158
+ placeholder="Enter your prompt here...",
159
+ lines=5
160
+ )
161
+ run_btn = gr.Button("Generate text (Stream)", variant="primary", size="lg")
162
+ output_text = gr.Textbox(label="Result", lines=15, show_copy_button=True)
163
 
164
  # Events
165
  load_btn.click(
166
  fn=load_new_model,
167
  inputs=[model_id_input],
168
+ outputs=[status_output]
169
  )
170
 
171
+ # We use `.click` targeting a generator function, which Gradio naturally treats as a streaming output
172
  run_btn.click(
173
  fn=run_inference,
174
+ inputs=[
175
+ user_prompt,
176
+ max_tokens_input,
177
+ temperature_input,
178
+ top_k_input,
179
+ top_p_input,
180
+ rep_penalty_input,
181
+ ngram_size_input,
182
+ do_sample_input
183
+ ],
184
  outputs=[output_text]
185
  )
186