File size: 19,512 Bytes
6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 fd406c7 275b4f2 fd406c7 275b4f2 fd406c7 ad95ef1 fd406c7 ad95ef1 fd406c7 ad95ef1 fd406c7 ad95ef1 e6cb750 6bd22b5 ad95ef1 e6cb750 6bd22b5 e6cb750 6bd22b5 e6cb750 6bd22b5 ad95ef1 6bd22b5 d6c6a2d 6bd22b5 ad95ef1 6bd22b5 d6c6a2d fc5f2ab 6bd22b5 7a33ddf ad95ef1 6bd22b5 e6cb750 ad95ef1 a5a87a2 ad95ef1 8758212 a5a87a2 ad95ef1 fc5f2ab ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 e6cb750 ad95ef1 72b6692 e6cb750 ad95ef1 6bd22b5 ad95ef1 e6cb750 72b6692 e6cb750 72b6692 ad95ef1 e6cb750 ad95ef1 275b4f2 ad95ef1 275b4f2 ad95ef1 6bd22b5 ad95ef1 e6cb750 ad95ef1 e6cb750 ad95ef1 fd406c7 ad95ef1 e6cb750 ad95ef1 e24705b ad95ef1 e24705b ad95ef1 e6cb750 ad95ef1 e6cb750 ad95ef1 e24705b ad95ef1 e6cb750 ad95ef1 e24705b ad95ef1 6bd22b5 7a33ddf 6bd22b5 7a33ddf 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 e6cb750 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 |
import gradio as gr
from typing import Optional, Dict, List
from datetime import datetime
from .config import AppConfig
from .session_manager import HackerNewsFineTuner
# --- Constants for Labels ---
# Rating symbols offered by each story's radio selector and stored as the
# values of the labels dict ({story index: symbol}).
LABEL_FAV = "👍"  # collected as a positive example by training_wrapper
LABEL_NEU = "😐"  # default rating; skipped by training_wrapper
LABEL_DIS = "👎"  # collected as a negative example by training_wrapper
# --- Session Wrappers ---
def refresh_wrapper(app):
    """Ensure a session object exists, refresh it, and clear the user's ratings.

    Args:
        app: Current session object. ``None``, any callable, or a class is
            treated as "no session yet" and replaced by a fresh
            ``HackerNewsFineTuner(AppConfig)``.

    Returns:
        A 4-tuple ``(app, choices_list, labels, log_update)`` where
        ``choices_list`` and ``log_update`` are the two values returned by
        ``app.refresh_data_and_model()`` and ``labels`` is a new empty dict.
    """
    session_missing = app is None or callable(app) or isinstance(app, type)
    if session_missing:
        print("Initializing new HackerNewsFineTuner session...")
        app = HackerNewsFineTuner(AppConfig)

    # refresh_data_and_model() yields (list of story titles, log text).
    refreshed = app.refresh_data_and_model()
    choices_list, log_update = refreshed

    # A brand-new dict resets every previously chosen label.
    return app, choices_list, {}, log_update
def update_hub_interactive(app, username: Optional[str] = None):
    """Compute interactivity updates for the two "push to Hub" controls.

    Args:
        app: Session object or ``None``; its ``last_hn_dataset`` attribute is
            read when a session exists.
        username: Signed-in user name, or ``None`` when not signed in.

    Returns:
        A pair of ``gr.update`` objects: the first is interactive whenever a
        username is present; the second additionally requires a session with
        a truthy ``last_hn_dataset``.
    """
    logged_in = username is not None
    dataset_ready = bool(app.last_hn_dataset) if app is not None else False

    repo_update = gr.update(interactive=logged_in)
    push_update = gr.update(interactive=logged_in and dataset_ready)
    return repo_update, push_update
def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
    """Startup handler: refresh the session and derive the Hub control state.

    Args:
        app: Current session object (may be ``None`` on first load).
        profile: OAuth profile of the visitor, or ``None`` when not signed in.

    Returns:
        A 7-tuple: session object, story list, labels dict, log text, the two
        Hub-control updates from ``update_hub_interactive``, and the username
        (``None`` when there is no profile).
    """
    # Initialise or refresh the session and its data.
    session, stories, labels, text_update = refresh_wrapper(app)

    # A falsy profile means nobody is signed in.
    username = None if not profile else profile.username

    # Enable/disable the Hub controls for this user and session.
    hub_updates = update_hub_interactive(session, username)
    repo_update, push_update = hub_updates

    return (
        session,
        stories,
        labels,
        text_update,
        repo_update,
        push_update,
        username,
    )
def update_repo_preview(username, repo_name):
    """Build the markdown preview showing the target ``username/repo_name``.

    Args:
        username: Signed-in user name; falsy when nobody is signed in.
        repo_name: Repository name typed by the user; may be ``None`` or empty.

    Returns:
        A sign-in prompt when ``username`` is falsy, otherwise a markdown
        string with the full repository path. A missing, empty, or
        whitespace-only ``repo_name`` is shown as the ``...`` placeholder.
    """
    if not username:
        return "⚠️ Sign in to see the target repository path."

    # Strip first, then fall back: a whitespace-only name would otherwise be
    # truthy, strip down to "" and render as a bare "username/".
    clean_repo = (repo_name or "").strip() or "..."
    return f"Target Repository: **`{username}/{clean_repo}`**"
def import_wrapper(app, file):
    """Forward an uploaded file to the session's ``import_additional_dataset``.

    Args:
        app: Session object, or ``None`` if the session has not been created
            yet (the session state starts empty and is filled on app load).
        file: Value of the upload component, passed through unchanged.

    Returns:
        Whatever ``app.import_additional_dataset(file)`` returns, or a short
        status message when no session exists.
    """
    if app is None:
        # The upload control is not locked during startup, so this handler can
        # fire before a session exists; report it instead of raising
        # AttributeError on None.
        return "⚠️ Session is still initializing. Please wait a moment and upload the file again."
    return app.import_additional_dataset(file)
def export_wrapper(app):
    """Call the session's ``export_dataset`` and hand back its result as-is."""
    exported = app.export_dataset()
    return exported
def download_model_wrapper(app):
    """Call the session's ``download_model`` and hand back its result as-is."""
    archive = app.download_model()
    return archive
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
    """Upload the model via the session, provided the user is signed in.

    Args:
        app: Session object exposing ``upload_model(repo_name, token)``.
        repo_name: Target repository name entered by the user.
        oauth_token: OAuth token of the signed-in user, or ``None``.

    Returns:
        The result of ``app.upload_model`` when a token is available,
        otherwise a warning string asking the user to sign in.
    """
    if oauth_token is not None:
        # Only the raw token string is handed to the session.
        token_str = oauth_token.token
        return app.upload_model(repo_name, token_str)
    return "⚠️ You must be logged in to push to the Hub. Please sign in above."
def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
    """Split story indices into favorites and dislikes, then start training.

    Args:
        app: Session object exposing ``training(pos_ids, neg_ids)``.
        stories: List of story titles; only its length is used here.
        labels: Mapping of story index to ``LABEL_FAV``, ``LABEL_NEU`` or
            ``LABEL_DIS``. Indices without an entry count as ``LABEL_NEU``.

    Returns:
        Whatever ``app.training(pos_ids, neg_ids)`` returns.
    """
    story_indices = range(len(stories))

    # Neutral (or unlabeled) stories land in neither list.
    pos_ids = [
        idx for idx in story_indices
        if labels.get(idx, LABEL_NEU) == LABEL_FAV
    ]
    neg_ids = [
        idx for idx in story_indices
        if labels.get(idx, LABEL_NEU) == LABEL_DIS
    ]
    return app.training(pos_ids, neg_ids)
def vibe_check_wrapper(app, text):
    """Forward a text to the session's ``get_vibe_check``.

    Args:
        app: Session object, or ``None`` if the session has not been created
            yet (the session state starts empty and is filled on app load).
        text: Text entered by the user, passed through unchanged.

    Returns:
        Whatever ``app.get_vibe_check(text)`` returns. When no session exists,
        a 4-tuple of placeholder strings is returned instead, one per output
        wired to the "Check Similarity" button (score, status, lamp style
        HTML, session info).
    """
    if app is None:
        # The "Check Similarity" button is never locked, so it can be clicked
        # before a session exists; answer with neutral placeholders instead of
        # raising AttributeError on None. The style string is the lamp's
        # initial gray style.
        return (
            "N/A",
            "⚠️ Session is still initializing. Please wait a moment and try again.",
            "<style>#mood_lamp input {background-color: gray;}</style>",
            "",
        )
    return app.get_vibe_check(text)
def mood_feed_wrapper(app):
    """Fetch the ranked feed from the session's ``fetch_and_display_mood_feed``.

    Args:
        app: Session object, or ``None`` if the session has not been created
            yet (the session state starts empty and is filled on app load).

    Returns:
        Whatever ``app.fetch_and_display_mood_feed()`` returns, or a short
        notice for the feed's Markdown output when no session exists.
    """
    if app is None:
        # The "Refresh Feed" button is never locked, so it can be clicked
        # before a session exists; show a notice instead of raising
        # AttributeError on None.
        return "⚠️ Session is still initializing. Please wait a moment and refresh again."
    return app.fetch_and_display_mood_feed()
# --- Interface Setup ---
def build_interface() -> gr.Blocks:
    """Build the Gradio UI and wire all of its event handlers.

    The layout has a header plus three tabs: "Train & Export" (sign-in, data
    selection, fine-tuning, push to Hub, downloads), "Live Ranked Feed" and
    "Vibe Check Playground". The Blocks object is returned without being
    launched.

    NOTE(review): the source this was recovered from had lost its indentation;
    the nesting of layout containers below is reconstructed from context and
    should be confirmed against the rendered UI.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="EmbeddingGemma Tuning Lab") as demo:
        # Session object handed to every wrapper; starts empty and is filled
        # by on_app_load / refresh_wrapper.
        session_state = gr.State()
        # Signed-in username, or None; set by on_app_load.
        username_state = gr.State()
        # State variables for the Feed List and User Choices
        stories_state = gr.State([])
        labels_state = gr.State({})
        # Incremented by "Reset Selection"; it is an input of the story-list
        # render function below.
        reset_counter = gr.State(0)

        with gr.Column():
            gr.Markdown("# 🤖 EmbeddingGemma Tuning Lab: Fine-Tuning and Mood Reader")
            gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-tuning-lab/blob/main/README.md) for more details.")

        with gr.Tab("⚙️ Train & Export"):
            # --- Model Indicator ---
            # Read-only dropdown that just displays the configured base model.
            gr.Dropdown(
                choices=[f"{AppConfig.MODEL_NAME}"],
                value=f"{AppConfig.MODEL_NAME}",
                label="Base Model for Fine-tuning",
                interactive=False
            )

            # --- Step 0: Login ---
            with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
                gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
                with gr.Row():
                    gr.LoginButton(value="Sign in with Hugging Face")
                    # Empty column, presumably a spacer next to the button.
                    with gr.Column(scale=3):
                        gr.Markdown("")

            # --- Step 1: Data Selection ---
            with gr.Accordion("1️⃣ Step 1: Select Data Source", open=True):
                gr.Markdown("Select titles from the live Hacker News feed **OR** upload your own CSV dataset to prepare your training data.")
                with gr.Column():
                    # Option A: Live Feed (Radio List)
                    with gr.Accordion("Option A: Live Hacker News Feed", open=True):
                        gr.Markdown("Rate the stories below to define your vibe.\n\n**⚠️ Note: You must select at least one Favorite and one Dislike to run training.**")
                        with gr.Row():
                            reset_all_btn = gr.Button("Reset Selection ↺", variant="secondary", scale=1)
                            with gr.Column(scale=3):
                                gr.Markdown("")

                        # Dynamic rendering of the story list
                        @gr.render(inputs=[stories_state, reset_counter])
                        def render_story_list(stories, _counter):
                            # _counter is unused in the body; it is only
                            # listed as a render input.
                            if not stories:
                                gr.Markdown("*No stories loaded. Click 'Reset Model & Fine-tuning state' to fetch data.*")
                                return
                            # Only the first 10 stories get a rating row.
                            for i, title in enumerate(stories[:10]):
                                with gr.Row(variant="compact", elem_id=f"story_row_{i}"):
                                    # Title
                                    with gr.Column(scale=2):
                                        gr.Markdown(f"{title}")
                                    # Radio Selection
                                    radio = gr.Radio(
                                        choices=[LABEL_FAV, LABEL_NEU, LABEL_DIS],
                                        value=LABEL_NEU,
                                        show_label=False,
                                        container=False,
                                        min_width=80,
                                        scale=1,
                                        interactive=True
                                    )

                                    # Update logic
                                    # idx=i binds the current loop index at
                                    # definition time so every radio writes to
                                    # its own key. The dict is mutated in
                                    # place and the same object is returned.
                                    def update_label(new_val, current_labels, idx=i):
                                        current_labels[idx] = new_val
                                        return current_labels

                                    radio.change(
                                        fn=update_label,
                                        inputs=[radio, labels_state],
                                        outputs=[labels_state]
                                    )

                    # Option B: Upload
                    with gr.Accordion("Option B: Upload Custom Dataset", open=False):
                        gr.Markdown("Upload a CSV file with columns (no header required, or header ignored if present): `Anchor`, `Positive`, `Negative`.")
                        gr.Markdown("See also: [example_training.dataset.csv](https://huggingface.co/spaces/google/embeddinggemma-tuning-lab/blob/main/example_training_dataset.csv)<br>Example:<br>`MY_FAVORITE_NEWS,Good Title,Bad Title`")
                        import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=100)

            # --- Step 2: Training ---
            with gr.Accordion("2️⃣ Step 2: Run Tuning", open=True):
                gr.Markdown("Fine-tune the model using the data selected or uploaded above.")
                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
                    clear_reload_btn = gr.Button("Reset Model & Fine-tuning state", scale=1)
                output = gr.Textbox(lines=10, label="Training Logs & Search Results", value="Waiting to start...", autoscroll=True)

            # --- Step 3: Push to Hub ---
            # Both controls start disabled; update_hub_interactive decides
            # when they become interactive.
            with gr.Accordion("3️⃣ Step 3: Save to Hugging Face Hub (Optional)", open=False):
                gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
                with gr.Row():
                    repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
                    push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
                repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
                push_status = gr.Markdown("")

            # --- Step 4: Downloads ---
            with gr.Accordion("4️⃣ Step 4: Download Artifacts", open=False):
                gr.Markdown("Export your combined dataset or download the fine-tuned model locally.")
                with gr.Row():
                    download_dataset_btn = gr.Button("💾 Export Dataset", interactive=False)
                    download_model_btn = gr.Button("⬇️ Download Model ZIP", interactive=False)
                download_status = gr.Markdown("Ready.")
                with gr.Row():
                    dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
                    model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)

            # --- Interaction Logic ---
            # Buttons that are locked together while a long-running step runs.
            # push_to_hub_btn is not in this list; the chains below handle it
            # separately.
            action_buttons = [
                clear_reload_btn,
                run_training_btn,
                download_dataset_btn,
                download_model_btn
            ]

            def set_interactivity(interactive: bool):
                """Helper to lock/unlock all main action buttons."""
                return [gr.update(interactive=interactive) for _ in action_buttons]

            # 1. App Startup
            # ----------------
            # Lock buttons -> create/refresh session -> update repo preview ->
            # unlock only reset and training (download buttons stay locked).
            demo.load(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=on_app_load,
                inputs=[session_state],
                outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
            ).then(
                fn=update_repo_preview,
                inputs=[username_state, repo_name_input],
                outputs=[repo_id_preview]
            ).then(
                fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
            )

            # 2. Reset / Refresh / Clear Selections
            # ----------------
            # Same shape as startup, but the Hub controls are recomputed at
            # the end from the refreshed session.
            clear_reload_btn.click(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                fn=refresh_wrapper,
                inputs=[session_state],
                outputs=[session_state, stories_state, labels_state, output]
            ).then(
                fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

            # Reset Selection Button Logic
            def reset_all_selections(counter):
                # Returns: (incremented counter, empty dict for labels)
                return counter + 1, {}

            reset_all_btn.click(
                fn=reset_all_selections,
                inputs=[reset_counter],
                outputs=[reset_counter, labels_state]
            )

            # 3. Import Data
            # ----------------
            # The import result is written to download_status (the Step 4
            # status line).
            import_file.change(
                fn=import_wrapper,
                inputs=[session_state, import_file],
                outputs=[download_status]
            )

            # 4. Run Training
            # ----------------
            run_training_btn.click(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=training_wrapper,
                inputs=[session_state, stories_state, labels_state],
                outputs=[output]
            ).then(
                # Unlock all buttons (including downloads now that we have a model)
                fn=lambda: set_interactivity(True), outputs=action_buttons
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

            # 5. Downloads
            # ----------------
            download_dataset_btn.click(
                fn=export_wrapper,
                inputs=[session_state],
                outputs=[dataset_output]
            ).then(
                # Just show the file output if it exists
                lambda p: gr.update(visible=True) if p else gr.update(),
                inputs=[dataset_output],
                outputs=[dataset_output]
            )

            download_model_btn.click(
                # Lock UI
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                # Reset previous outputs and show "Zipping..."
                fn=lambda: [gr.update(value=None, visible=False), "⏳ Zipping model..."],
                outputs=[model_output, download_status]
            ).then(
                # Generate Zip
                fn=download_model_wrapper,
                inputs=[session_state],
                outputs=[model_output]
            ).then(
                # Update UI with result
                fn=lambda p: [gr.update(visible=p is not None, value=p), "✅ ZIP ready." if p else "❌ Zipping failed."],
                inputs=[model_output],
                outputs=[model_output, download_status]
            ).then(
                # Unlock UI
                fn=lambda: set_interactivity(True), outputs=action_buttons
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

            # 6. Push to Hub
            # ----------------
            repo_name_input.change(
                fn=update_repo_preview,
                inputs=[username_state, repo_name_input],
                outputs=[repo_id_preview]
            )

            # push_to_hub_wrapper takes a third, OAuth-token parameter that is
            # not listed in `inputs` here.
            # NOTE(review): presumably Gradio injects it from the type
            # annotation — confirm against the Gradio OAuth docs.
            push_to_hub_btn.click(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                fn=push_to_hub_wrapper,
                inputs=[session_state, repo_name_input],
                outputs=[push_status]
            ).then(
                fn=lambda: set_interactivity(True), outputs=action_buttons
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

        with gr.Tab("📰 Live Ranked Feed"):
            with gr.Column():
                gr.Markdown(f"## Live Hacker News Feed Vibe")
                gr.Markdown(f"This feed uses the current model (base or fine-tuned) to score the vibe of live Hacker News stories against **`{AppConfig.QUERY_ANCHOR}`**.")
                feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
                refresh_button = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
            refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)

        with gr.Tab("🧪 Vibe Check Playground"):
            with gr.Column():
                gr.Markdown(f"## News Similarity Check")
                gr.Markdown(f"Enter text to see its similarity to **`{AppConfig.QUERY_ANCHOR}`**.<br>**Vibe Key:** <span style='color:green'>Green = High</span>, <span style='color:yellow'>Yellow = Neutral</span>, <span style='color:red'>Red = Low</span>")
                # Created with render=False so gr.Examples can reference it
                # before it is placed; .render() below places it after the
                # examples.
                news_input = gr.Textbox(label="Enter News Title or Summary", lines=3, render=False)
                gr.Examples(
                    examples=[
                        "Global Markets Rally as Inflation Data Shows Unexpected Drop for Third Consecutive Month",
                        "Astronomers Detect Strong Oxygen Signature on Potentially Habitable Exoplanet",
                        "City Council Approves Controversial Plan to Ban Cars from Downtown District by 2027",
                        "Tech Giant Unveils Prototype for \"Invisible\" AR Glasses, Promising a Screen-Free Future",
                        "Local Library Receives Overdue Book Checked Out in 1948 With An Anonymous Apology Note"
                    ],
                    inputs=news_input,
                    label="Try these examples"
                )
                news_input.render()
                vibe_check_btn = gr.Button("Check Similarity", variant="primary")
                session_info_display = gr.Markdown()
            with gr.Column():
                vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
                vibe_lamp = gr.Textbox(label="Mood Lamp", max_lines=1, elem_id="mood_lamp", interactive=False)
                vibe_status = gr.Textbox(label="Status", value="...", interactive=False)
                # HTML component carrying a <style> rule for #mood_lamp; it is
                # one of the outputs of the vibe check, while vibe_lamp itself
                # is not.
                style_thml = gr.HTML(value="<style>#mood_lamp input {background-color: gray;}</style>")
            vibe_check_btn.click(
                fn=vibe_check_wrapper,
                inputs=[session_state, news_input],
                outputs=[vibe_score, vibe_status, style_thml, session_info_display]
            )
    return demo
|