Upload folder using huggingface_hub
Browse files- app.py +2 -525
- src/session_manager.py +306 -0
- src/ui.py +220 -0
- src/vibe_logic.py +1 -1
app.py
CHANGED
|
@@ -1,530 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
| 3 |
-
import shutil
|
| 4 |
-
import time
|
| 5 |
-
import csv
|
| 6 |
-
import uuid
|
| 7 |
-
from itertools import cycle
|
| 8 |
-
from typing import List, Iterable, Tuple, Optional, Callable
|
| 9 |
-
from datetime import datetime
|
| 10 |
-
|
| 11 |
-
# Import modules
|
| 12 |
-
from src.data_fetcher import read_hacker_news_rss, format_published_time
|
| 13 |
-
from src.model_trainer import (
|
| 14 |
-
authenticate_hf,
|
| 15 |
-
train_with_dataset,
|
| 16 |
-
get_top_hits,
|
| 17 |
-
load_embedding_model,
|
| 18 |
-
upload_model_to_hub
|
| 19 |
-
)
|
| 20 |
-
from src.config import AppConfig
|
| 21 |
-
from src.vibe_logic import VibeChecker
|
| 22 |
-
from sentence_transformers import SentenceTransformer
|
| 23 |
-
|
| 24 |
-
# --- Main Application Class (Session Scoped) ---
|
| 25 |
-
|
| 26 |
-
class HackerNewsFineTuner:
|
| 27 |
-
"""
|
| 28 |
-
Encapsulates all application logic and state for a single user session.
|
| 29 |
-
"""
|
| 30 |
-
|
| 31 |
-
def __init__(self, config: AppConfig = AppConfig):
|
| 32 |
-
# --- Dependencies ---
|
| 33 |
-
self.config = config
|
| 34 |
-
|
| 35 |
-
# --- Session Identification ---
|
| 36 |
-
self.session_id = str(uuid.uuid4())
|
| 37 |
-
|
| 38 |
-
# Define session-specific paths to allow simultaneous training
|
| 39 |
-
self.session_root = self.config.ARTIFACTS_DIR / self.session_id
|
| 40 |
-
self.output_dir = self.session_root / "embedding_gemma_finetuned"
|
| 41 |
-
self.dataset_export_file = self.session_root / "training_dataset.csv"
|
| 42 |
-
|
| 43 |
-
# Setup directories
|
| 44 |
-
os.makedirs(self.output_dir, exist_ok=True)
|
| 45 |
-
print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")
|
| 46 |
-
|
| 47 |
-
# --- Application State ---
|
| 48 |
-
self.model: Optional[SentenceTransformer] = None
|
| 49 |
-
self.vibe_checker: Optional[VibeChecker] = None
|
| 50 |
-
self.titles: List[str] = []
|
| 51 |
-
self.target_titles: List[str] = []
|
| 52 |
-
self.number_list: List[int] = []
|
| 53 |
-
self.last_hn_dataset: List[List[str]] = []
|
| 54 |
-
self.imported_dataset: List[List[str]] = []
|
| 55 |
-
|
| 56 |
-
# Authenticate once (global)
|
| 57 |
-
authenticate_hf(self.config.HF_TOKEN)
|
| 58 |
-
|
| 59 |
-
def _update_vibe_checker(self):
|
| 60 |
-
"""Initializes or updates the VibeChecker with the current model state."""
|
| 61 |
-
if self.model:
|
| 62 |
-
self.vibe_checker = VibeChecker(
|
| 63 |
-
model=self.model,
|
| 64 |
-
query_anchor=self.config.QUERY_ANCHOR,
|
| 65 |
-
task_name=self.config.TASK_NAME
|
| 66 |
-
)
|
| 67 |
-
else:
|
| 68 |
-
self.vibe_checker = None
|
| 69 |
-
|
| 70 |
-
## Data and Model Management ##
|
| 71 |
-
|
| 72 |
-
def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
|
| 73 |
-
"""
|
| 74 |
-
Reloads model and fetches data.
|
| 75 |
-
"""
|
| 76 |
-
print(f"[{self.session_id}] Reloading model and data...")
|
| 77 |
-
|
| 78 |
-
self.last_hn_dataset = []
|
| 79 |
-
self.imported_dataset = []
|
| 80 |
-
|
| 81 |
-
# 1. Reload the base embedding model
|
| 82 |
-
try:
|
| 83 |
-
self.model = load_embedding_model(self.config.MODEL_NAME)
|
| 84 |
-
self._update_vibe_checker()
|
| 85 |
-
except Exception as e:
|
| 86 |
-
error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
|
| 87 |
-
print(error_msg)
|
| 88 |
-
self.model = None
|
| 89 |
-
self._update_vibe_checker()
|
| 90 |
-
return (
|
| 91 |
-
gr.update(choices=[], label="Model Load Failed"),
|
| 92 |
-
gr.update(value=error_msg)
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
# 2. Fetch fresh news data
|
| 96 |
-
news_feed, status_msg = read_hacker_news_rss(self.config)
|
| 97 |
-
titles_out, target_titles_out = [], []
|
| 98 |
-
status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
|
| 99 |
-
|
| 100 |
-
if news_feed is not None and news_feed.entries:
|
| 101 |
-
titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
|
| 102 |
-
target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
|
| 103 |
-
else:
|
| 104 |
-
titles_out = ["Error fetching news.", "Check console."]
|
| 105 |
-
gr.Warning(f"Data reload failed. {status_msg}")
|
| 106 |
-
|
| 107 |
-
self.titles = titles_out
|
| 108 |
-
self.target_titles = target_titles_out
|
| 109 |
-
self.number_list = list(range(len(self.titles)))
|
| 110 |
-
|
| 111 |
-
return (
|
| 112 |
-
gr.update(
|
| 113 |
-
choices=self.titles,
|
| 114 |
-
label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
|
| 115 |
-
),
|
| 116 |
-
gr.update(value=status_value)
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
# --- Import Dataset/Export ---
|
| 120 |
-
def import_additional_dataset(self, file_path: str) -> str:
|
| 121 |
-
if not file_path:
|
| 122 |
-
return "Please upload a CSV file."
|
| 123 |
-
new_dataset, num_imported = [], 0
|
| 124 |
-
try:
|
| 125 |
-
with open(file_path, 'r', newline='', encoding='utf-8') as f:
|
| 126 |
-
reader = csv.reader(f)
|
| 127 |
-
try:
|
| 128 |
-
header = next(reader)
|
| 129 |
-
if not (header and header[0].lower().strip() == 'anchor'):
|
| 130 |
-
f.seek(0)
|
| 131 |
-
except StopIteration:
|
| 132 |
-
return "Error: Uploaded file is empty."
|
| 133 |
-
|
| 134 |
-
for row in reader:
|
| 135 |
-
if len(row) == 3:
|
| 136 |
-
new_dataset.append([s.strip() for s in row])
|
| 137 |
-
num_imported += 1
|
| 138 |
-
if num_imported == 0:
|
| 139 |
-
raise ValueError("No valid rows found.")
|
| 140 |
-
self.imported_dataset = new_dataset
|
| 141 |
-
return f"Imported {num_imported} triplets."
|
| 142 |
-
except Exception as e:
|
| 143 |
-
return f"Import failed: {e}"
|
| 144 |
-
|
| 145 |
-
def export_dataset(self) -> Optional[str]:
|
| 146 |
-
if not self.last_hn_dataset:
|
| 147 |
-
gr.Warning("No dataset generated yet.")
|
| 148 |
-
return None
|
| 149 |
-
|
| 150 |
-
file_path = self.dataset_export_file
|
| 151 |
-
try:
|
| 152 |
-
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
| 153 |
-
writer = csv.writer(f)
|
| 154 |
-
writer.writerow(['Anchor', 'Positive', 'Negative'])
|
| 155 |
-
writer.writerows(self.last_hn_dataset)
|
| 156 |
-
gr.Info(f"Dataset exported.")
|
| 157 |
-
return str(file_path)
|
| 158 |
-
except Exception as e:
|
| 159 |
-
gr.Error(f"Export failed: {e}")
|
| 160 |
-
return None
|
| 161 |
-
|
| 162 |
-
def download_model(self) -> Optional[str]:
|
| 163 |
-
if not os.path.exists(self.output_dir):
|
| 164 |
-
gr.Warning("No model trained yet.")
|
| 165 |
-
return None
|
| 166 |
-
|
| 167 |
-
timestamp = int(time.time())
|
| 168 |
-
try:
|
| 169 |
-
base_name = self.session_root / f"model_finetuned_{timestamp}"
|
| 170 |
-
archive_path = shutil.make_archive(
|
| 171 |
-
base_name=str(base_name),
|
| 172 |
-
format='zip',
|
| 173 |
-
root_dir=self.output_dir,
|
| 174 |
-
)
|
| 175 |
-
gr.Info(f"Model zipped.")
|
| 176 |
-
return archive_path
|
| 177 |
-
except Exception as e:
|
| 178 |
-
gr.Error(f"Zip failed: {e}")
|
| 179 |
-
return None
|
| 180 |
-
|
| 181 |
-
def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
|
| 182 |
-
"""
|
| 183 |
-
Calls the model trainer upload function using the session's output directory.
|
| 184 |
-
"""
|
| 185 |
-
if not os.path.exists(self.output_dir):
|
| 186 |
-
return "β Error: No trained model found in this session. Run training first."
|
| 187 |
-
if not repo_name.strip():
|
| 188 |
-
return "β Error: Please specify a repository name."
|
| 189 |
-
|
| 190 |
-
return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
## Training Logic ##
|
| 194 |
-
def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
|
| 195 |
-
total_ids, selected_ids = set(self.number_list), set(selected_ids)
|
| 196 |
-
non_selected_ids = total_ids - selected_ids
|
| 197 |
-
is_minority = len(selected_ids) < (len(total_ids) / 2)
|
| 198 |
-
|
| 199 |
-
anchor_ids, pool_ids = (non_selected_ids, list(selected_ids)) if is_minority else (selected_ids, list(non_selected_ids))
|
| 200 |
-
|
| 201 |
-
def get_titles(anchor_id, pool_id):
|
| 202 |
-
return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
|
| 203 |
-
|
| 204 |
-
if not pool_ids or not anchor_ids:
|
| 205 |
-
return [], "", ""
|
| 206 |
-
|
| 207 |
-
fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
|
| 208 |
-
non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
|
| 209 |
-
|
| 210 |
-
hn_dataset = []
|
| 211 |
-
pool_cycler = cycle(pool_ids)
|
| 212 |
-
for anchor_id in sorted(list(anchor_ids)):
|
| 213 |
-
fav, non_fav = get_titles(anchor_id, next(pool_cycler))
|
| 214 |
-
hn_dataset.append([self.config.QUERY_ANCHOR, fav, non_fav])
|
| 215 |
-
|
| 216 |
-
return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
|
| 217 |
-
|
| 218 |
-
def training(self, selected_ids: List[int]) -> str:
|
| 219 |
-
if self.model is None:
|
| 220 |
-
raise gr.Error("Model not loaded.")
|
| 221 |
-
if not selected_ids:
|
| 222 |
-
raise gr.Error("Select at least one title.")
|
| 223 |
-
if len(selected_ids) == len(self.number_list):
|
| 224 |
-
raise gr.Error("Cannot select all titles.")
|
| 225 |
-
|
| 226 |
-
hn_dataset, _, _ = self._create_hn_dataset(selected_ids)
|
| 227 |
-
self.last_hn_dataset = hn_dataset
|
| 228 |
-
final_dataset = self.last_hn_dataset + self.imported_dataset
|
| 229 |
-
|
| 230 |
-
if not final_dataset:
|
| 231 |
-
raise gr.Error("Dataset is empty.")
|
| 232 |
-
|
| 233 |
-
def semantic_search_fn() -> str:
|
| 234 |
-
return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
|
| 235 |
-
|
| 236 |
-
result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
|
| 237 |
-
print(f"[{self.session_id}] Starting Training...")
|
| 238 |
-
|
| 239 |
-
train_with_dataset(
|
| 240 |
-
model=self.model,
|
| 241 |
-
dataset=final_dataset,
|
| 242 |
-
output_dir=self.output_dir,
|
| 243 |
-
task_name=self.config.TASK_NAME,
|
| 244 |
-
search_fn=semantic_search_fn
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
self._update_vibe_checker()
|
| 248 |
-
print(f"[{self.session_id}] Training Complete.")
|
| 249 |
-
|
| 250 |
-
result += "### Search (After):\n" + f"{semantic_search_fn()}"
|
| 251 |
-
return result
|
| 252 |
-
|
| 253 |
-
## Vibe Check Logic ##
|
| 254 |
-
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
|
| 255 |
-
info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
|
| 256 |
-
|
| 257 |
-
if not self.vibe_checker:
|
| 258 |
-
return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_html("gray")), info_text
|
| 259 |
-
if not news_text or len(news_text.split()) < 3:
|
| 260 |
-
return "N/A", "Text too short", gr.update(value=self._generate_vibe_html("white")), info_text
|
| 261 |
-
|
| 262 |
-
try:
|
| 263 |
-
vibe_result = self.vibe_checker.check(news_text)
|
| 264 |
-
status = vibe_result.status_html.split('>')[1].split('<')[0]
|
| 265 |
-
return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl)), info_text
|
| 266 |
-
except Exception as e:
|
| 267 |
-
return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_html("gray")), info_text
|
| 268 |
-
|
| 269 |
-
def _generate_vibe_html(self, color: str) -> str:
|
| 270 |
-
return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
|
| 271 |
-
|
| 272 |
-
## Mood Reader Logic ##
|
| 273 |
-
def fetch_and_display_mood_feed(self) -> str:
|
| 274 |
-
if not self.vibe_checker:
|
| 275 |
-
return "Model not ready. Please wait or reload."
|
| 276 |
-
|
| 277 |
-
feed, status = read_hacker_news_rss(self.config)
|
| 278 |
-
if not feed or not feed.entries:
|
| 279 |
-
return f"**Feed Error:** {status}"
|
| 280 |
-
|
| 281 |
-
scored_entries = []
|
| 282 |
-
for entry in feed.entries:
|
| 283 |
-
title = entry.get('title')
|
| 284 |
-
if not title: continue
|
| 285 |
-
|
| 286 |
-
vibe_result = self.vibe_checker.check(title)
|
| 287 |
-
scored_entries.append({
|
| 288 |
-
"title": title,
|
| 289 |
-
"link": entry.get('link', '#'),
|
| 290 |
-
"comments": entry.get('comments', '#'),
|
| 291 |
-
"published": format_published_time(entry.published_parsed),
|
| 292 |
-
"mood": vibe_result
|
| 293 |
-
})
|
| 294 |
-
|
| 295 |
-
scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
|
| 296 |
-
|
| 297 |
-
md = (f"## Hacker News Top Stories\n"
|
| 298 |
-
f"**Session:** {self.session_id[:6]}<br>"
|
| 299 |
-
f"**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}<br>"
|
| 300 |
-
f"**Updated:** {datetime.now().strftime('%H:%M:%S')}\n\n"
|
| 301 |
-
"| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
|
| 302 |
-
|
| 303 |
-
for item in scored_entries:
|
| 304 |
-
md += (f"| {item['mood'].status_html} "
|
| 305 |
-
f"| {item['mood'].raw_score:.4f} "
|
| 306 |
-
f"| [{item['title']}]({item['link']}) "
|
| 307 |
-
f"| [Comments]({item['comments']}) "
|
| 308 |
-
f"| {item['published']} |\n")
|
| 309 |
-
return md
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
# --- Session Wrappers ---
|
| 313 |
-
|
| 314 |
-
def refresh_wrapper(app):
|
| 315 |
-
"""
|
| 316 |
-
Initializes the session if it's not already created, then runs the refresh.
|
| 317 |
-
Returns the app instance to update the State.
|
| 318 |
-
"""
|
| 319 |
-
if app is None or callable(app) or isinstance(app, type):
|
| 320 |
-
print("Initializing new HackerNewsFineTuner session...")
|
| 321 |
-
app = HackerNewsFineTuner(AppConfig)
|
| 322 |
-
|
| 323 |
-
# Run the refresh logic
|
| 324 |
-
update1, update2 = app.refresh_data_and_model()
|
| 325 |
-
|
| 326 |
-
# Return 3 items: The App Instance (for State), Choice Update, Text Update
|
| 327 |
-
return app, update1, update2
|
| 328 |
-
|
| 329 |
-
def import_wrapper(app, file):
|
| 330 |
-
return app.import_additional_dataset(file)
|
| 331 |
-
|
| 332 |
-
def export_wrapper(app):
|
| 333 |
-
return app.export_dataset()
|
| 334 |
-
|
| 335 |
-
def download_model_wrapper(app):
|
| 336 |
-
return app.download_model()
|
| 337 |
-
|
| 338 |
-
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
|
| 339 |
-
"""
|
| 340 |
-
Wrapper for pushing the model to the Hugging Face Hub.
|
| 341 |
-
Gradio automatically injects 'oauth_token' if the user is logged in via LoginButton.
|
| 342 |
-
"""
|
| 343 |
-
if oauth_token is None:
|
| 344 |
-
return "β οΈ You must be logged in to push to the Hub. Please sign in above."
|
| 345 |
-
|
| 346 |
-
# Extract the token string from the OAuthToken object
|
| 347 |
-
token_str = oauth_token.token
|
| 348 |
-
return app.upload_model(repo_name, token_str)
|
| 349 |
-
|
| 350 |
-
def training_wrapper(app, selected_ids):
|
| 351 |
-
return app.training(selected_ids)
|
| 352 |
-
|
| 353 |
-
def vibe_check_wrapper(app, text):
|
| 354 |
-
return app.get_vibe_check(text)
|
| 355 |
-
|
| 356 |
-
def mood_feed_wrapper(app):
|
| 357 |
-
return app.fetch_and_display_mood_feed()
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
# --- Interface Setup ---
|
| 361 |
-
|
| 362 |
-
def build_interface() -> gr.Blocks:
|
| 363 |
-
with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
|
| 364 |
-
# Initialize state as None. It will be populated by refresh_wrapper on load.
|
| 365 |
-
session_state = gr.State()
|
| 366 |
-
|
| 367 |
-
with gr.Column():
|
| 368 |
-
gr.Markdown("# π€ EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
|
| 369 |
-
gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
|
| 370 |
-
gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")
|
| 371 |
-
|
| 372 |
-
with gr.Tab("π Fine-Tuning & Evaluation"):
|
| 373 |
-
with gr.Column():
|
| 374 |
-
gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
|
| 375 |
-
with gr.Row():
|
| 376 |
-
favorite_list = gr.CheckboxGroup(choices=[], type="index", label="Hacker News Top Stories", show_select_all=True)
|
| 377 |
-
output = gr.Textbox(lines=14, label="Training and Search Results", value="Loading data...")
|
| 378 |
-
|
| 379 |
-
with gr.Row():
|
| 380 |
-
clear_reload_btn = gr.Button("Clear & Reload")
|
| 381 |
-
run_training_btn = gr.Button("π Run Fine-Tuning", variant="primary")
|
| 382 |
-
|
| 383 |
-
gr.Markdown("--- \n ## Dataset & Model Management")
|
| 384 |
-
gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
|
| 385 |
-
import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
|
| 386 |
-
|
| 387 |
-
with gr.Row():
|
| 388 |
-
download_dataset_btn = gr.Button("πΎ Export Dataset")
|
| 389 |
-
download_model_btn = gr.Button("β¬οΈ Download Fine-Tuned Model")
|
| 390 |
-
|
| 391 |
-
download_status = gr.Markdown("Ready.")
|
| 392 |
-
|
| 393 |
-
with gr.Row():
|
| 394 |
-
dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
|
| 395 |
-
model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
|
| 396 |
-
|
| 397 |
-
gr.Markdown("### βοΈ Publish to Hugging Face Hub")
|
| 398 |
-
with gr.Row():
|
| 399 |
-
repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g., my-news-vibe-model")
|
| 400 |
-
push_to_hub_btn = gr.Button("Push to Hub", variant="secondary")
|
| 401 |
-
|
| 402 |
-
push_status = gr.Markdown("")
|
| 403 |
-
|
| 404 |
-
# --- Interactions ---
|
| 405 |
-
|
| 406 |
-
# 1. Initial Load: Initialize State and Load Data
|
| 407 |
-
demo.load(
|
| 408 |
-
fn=refresh_wrapper,
|
| 409 |
-
inputs=[session_state],
|
| 410 |
-
outputs=[session_state, favorite_list, output]
|
| 411 |
-
)
|
| 412 |
-
|
| 413 |
-
buttons_to_lock = [
|
| 414 |
-
clear_reload_btn,
|
| 415 |
-
run_training_btn,
|
| 416 |
-
download_dataset_btn,
|
| 417 |
-
download_model_btn,
|
| 418 |
-
push_to_hub_btn
|
| 419 |
-
]
|
| 420 |
-
|
| 421 |
-
# 2. Buttons
|
| 422 |
-
clear_reload_btn.click(
|
| 423 |
-
fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
|
| 424 |
-
outputs=buttons_to_lock
|
| 425 |
-
).then(
|
| 426 |
-
fn=refresh_wrapper,
|
| 427 |
-
inputs=[session_state],
|
| 428 |
-
outputs=[session_state, favorite_list, output]
|
| 429 |
-
).then(
|
| 430 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 431 |
-
outputs=buttons_to_lock
|
| 432 |
-
)
|
| 433 |
-
|
| 434 |
-
run_training_btn.click(
|
| 435 |
-
fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
|
| 436 |
-
outputs=buttons_to_lock
|
| 437 |
-
).then(
|
| 438 |
-
fn=training_wrapper,
|
| 439 |
-
inputs=[session_state, favorite_list],
|
| 440 |
-
outputs=[output]
|
| 441 |
-
).then(
|
| 442 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 443 |
-
outputs=buttons_to_lock
|
| 444 |
-
)
|
| 445 |
-
|
| 446 |
-
import_file.change(
|
| 447 |
-
fn=import_wrapper,
|
| 448 |
-
inputs=[session_state, import_file],
|
| 449 |
-
outputs=[download_status]
|
| 450 |
-
)
|
| 451 |
-
|
| 452 |
-
download_dataset_btn.click(
|
| 453 |
-
fn=export_wrapper,
|
| 454 |
-
inputs=[session_state],
|
| 455 |
-
outputs=[dataset_output]
|
| 456 |
-
).then(
|
| 457 |
-
lambda p: gr.update(visible=True) if p else gr.update(), inputs=[dataset_output], outputs=[dataset_output]
|
| 458 |
-
)
|
| 459 |
-
|
| 460 |
-
download_model_btn.click(
|
| 461 |
-
fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
|
| 462 |
-
outputs=buttons_to_lock
|
| 463 |
-
).then(
|
| 464 |
-
lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
|
| 465 |
-
).then(
|
| 466 |
-
fn=download_model_wrapper,
|
| 467 |
-
inputs=[session_state],
|
| 468 |
-
outputs=[model_output]
|
| 469 |
-
).then(
|
| 470 |
-
lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
|
| 471 |
-
).then(
|
| 472 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 473 |
-
outputs=buttons_to_lock
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
# Push to Hub Interaction
|
| 477 |
-
push_to_hub_btn.click(
|
| 478 |
-
fn=push_to_hub_wrapper,
|
| 479 |
-
inputs=[session_state, repo_name_input],
|
| 480 |
-
outputs=[push_status]
|
| 481 |
-
)
|
| 482 |
-
|
| 483 |
-
with gr.Tab("π° Hacker News Similarity Check"):
|
| 484 |
-
with gr.Column():
|
| 485 |
-
gr.Markdown(f"## Live Hacker News Feed Vibe")
|
| 486 |
-
gr.Markdown(f"This feed uses the current model (base or fine-tuned) to score the vibe of live Hacker News stories against **`{AppConfig.QUERY_ANCHOR}`**.")
|
| 487 |
-
feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
|
| 488 |
-
refresh_button = gr.Button("Refresh Feed π", size="lg", variant="primary")
|
| 489 |
-
refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)
|
| 490 |
-
|
| 491 |
-
with gr.Tab("π‘ Similarity Lamp"):
|
| 492 |
-
with gr.Column():
|
| 493 |
-
gr.Markdown(f"## News Similarity Check")
|
| 494 |
-
gr.Markdown(f"Enter text to see its similarity to **`{AppConfig.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
|
| 495 |
-
news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
|
| 496 |
-
vibe_check_btn = gr.Button("Check Similarity", variant="primary")
|
| 497 |
-
|
| 498 |
-
gr.Examples(
|
| 499 |
-
examples=[
|
| 500 |
-
"Global Markets Rally as Inflation Data Shows Unexpected Drop for Third Consecutive Month",
|
| 501 |
-
"Astronomers Detect Strong Oxygen Signature on Potentially Habitable Exoplanet",
|
| 502 |
-
"City Council Approves Controversial Plan to Ban Cars from Downtown District by 2027",
|
| 503 |
-
"Tech Giant Unveils Prototype for \"Invisible\" AR Glasses, Promising a Screen-Free Future",
|
| 504 |
-
"Local Library Receives Overdue Book Checked Out in 1948 With An Anonymous Apology Note"
|
| 505 |
-
],
|
| 506 |
-
inputs=news_input,
|
| 507 |
-
label="Try these examples"
|
| 508 |
-
)
|
| 509 |
-
|
| 510 |
-
session_info_display = gr.Markdown()
|
| 511 |
-
|
| 512 |
-
with gr.Row():
|
| 513 |
-
vibe_color_block = gr.HTML(value='<div style="background-color: gray; height: 100px;"></div>', label="Mood Lamp")
|
| 514 |
-
with gr.Column():
|
| 515 |
-
vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
|
| 516 |
-
vibe_status = gr.Textbox(label="Status", value="...", interactive=False)
|
| 517 |
-
|
| 518 |
-
vibe_check_btn.click(
|
| 519 |
-
fn=vibe_check_wrapper,
|
| 520 |
-
inputs=[session_state, news_input],
|
| 521 |
-
outputs=[vibe_score, vibe_status, vibe_color_block, session_info_display]
|
| 522 |
-
)
|
| 523 |
-
|
| 524 |
-
return demo
|
| 525 |
|
| 526 |
if __name__ == "__main__":
|
| 527 |
app_demo = build_interface()
|
| 528 |
print("Starting Multi-User Gradio App...")
|
| 529 |
app_demo.queue()
|
| 530 |
-
app_demo.launch()
|
|
|
|
| 1 |
+
from src.ui import build_interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
if __name__ == "__main__":
|
| 4 |
app_demo = build_interface()
|
| 5 |
print("Starting Multi-User Gradio App...")
|
| 6 |
app_demo.queue()
|
| 7 |
+
app_demo.launch()
|
src/session_manager.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import time
|
| 4 |
+
import csv
|
| 5 |
+
import uuid
|
| 6 |
+
from itertools import cycle
|
| 7 |
+
from typing import List, Tuple, Optional
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
import gradio as gr # Needed for gr.update, gr.Warning, gr.Info, gr.Error
|
| 10 |
+
|
| 11 |
+
from .data_fetcher import read_hacker_news_rss, format_published_time
|
| 12 |
+
from .model_trainer import (
|
| 13 |
+
authenticate_hf,
|
| 14 |
+
train_with_dataset,
|
| 15 |
+
get_top_hits,
|
| 16 |
+
load_embedding_model,
|
| 17 |
+
upload_model_to_hub
|
| 18 |
+
)
|
| 19 |
+
from .config import AppConfig
|
| 20 |
+
from .vibe_logic import VibeChecker
|
| 21 |
+
from sentence_transformers import SentenceTransformer
|
| 22 |
+
|
| 23 |
+
class HackerNewsFineTuner:
|
| 24 |
+
"""
|
| 25 |
+
Encapsulates all application logic and state for a single user session.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, config: AppConfig = AppConfig):
|
| 29 |
+
# --- Dependencies ---
|
| 30 |
+
self.config = config
|
| 31 |
+
|
| 32 |
+
# --- Session Identification ---
|
| 33 |
+
self.session_id = str(uuid.uuid4())
|
| 34 |
+
|
| 35 |
+
# Define session-specific paths to allow simultaneous training
|
| 36 |
+
self.session_root = self.config.ARTIFACTS_DIR / self.session_id
|
| 37 |
+
self.output_dir = self.session_root / "embedding_gemma_finetuned"
|
| 38 |
+
self.dataset_export_file = self.session_root / "training_dataset.csv"
|
| 39 |
+
|
| 40 |
+
# Setup directories
|
| 41 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
| 42 |
+
print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")
|
| 43 |
+
|
| 44 |
+
# --- Application State ---
|
| 45 |
+
self.model: Optional[SentenceTransformer] = None
|
| 46 |
+
self.vibe_checker: Optional[VibeChecker] = None
|
| 47 |
+
self.titles: List[str] = []
|
| 48 |
+
self.target_titles: List[str] = []
|
| 49 |
+
self.number_list: List[int] = []
|
| 50 |
+
self.last_hn_dataset: List[List[str]] = []
|
| 51 |
+
self.imported_dataset: List[List[str]] = []
|
| 52 |
+
|
| 53 |
+
# Authenticate once (global)
|
| 54 |
+
authenticate_hf(self.config.HF_TOKEN)
|
| 55 |
+
|
| 56 |
+
def _update_vibe_checker(self):
|
| 57 |
+
"""Initializes or updates the VibeChecker with the current model state."""
|
| 58 |
+
if self.model:
|
| 59 |
+
self.vibe_checker = VibeChecker(
|
| 60 |
+
model=self.model,
|
| 61 |
+
query_anchor=self.config.QUERY_ANCHOR,
|
| 62 |
+
task_name=self.config.TASK_NAME
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
self.vibe_checker = None
|
| 66 |
+
|
| 67 |
+
## Data and Model Management ##
|
| 68 |
+
|
| 69 |
+
def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
    """
    Reloads model and fetches data.

    Returns a pair of ``gr.update`` objects: one for the title checkbox
    group (its choices/label) and one for the status textbox (its value).
    Any previously built or imported training data is discarded first.
    """
    print(f"[{self.session_id}] Reloading model and data...")

    # Reset per-session training data so the UI starts from a clean slate.
    self.last_hn_dataset = []
    self.imported_dataset = []

    # 1. Reload the base embedding model
    try:
        self.model = load_embedding_model(self.config.MODEL_NAME)
        self._update_vibe_checker()
    except Exception as e:
        # Model load failure is fatal for this session: surface it in the
        # UI and leave the checker cleared.
        error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
        print(error_msg)
        self.model = None
        self._update_vibe_checker()
        return (
            gr.update(choices=[], label="Model Load Failed"),
            gr.update(value=error_msg)
        )

    # 2. Fetch fresh news data
    news_feed, status_msg = read_hacker_news_rss(self.config)
    titles_out, target_titles_out = [], []
    status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"

    if news_feed is not None and news_feed.entries:
        # First TOP_TITLES_COUNT entries become the selectable list; the
        # remainder is the held-out corpus for search evaluation.
        titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
        target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
    else:
        # Feed failure is non-fatal: show placeholder entries and warn.
        titles_out = ["Error fetching news.", "Check console."]
        gr.Warning(f"Data reload failed. {status_msg}")

    self.titles = titles_out
    self.target_titles = target_titles_out
    self.number_list = list(range(len(self.titles)))

    return (
        gr.update(
            choices=self.titles,
            label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
        ),
        gr.update(value=status_value)
    )
|
| 115 |
+
|
| 116 |
+
# --- Import Dataset/Export ---
|
| 117 |
+
def import_additional_dataset(self, file_path: str) -> str:
    """
    Load extra (anchor, positive, negative) triplets from a CSV file.

    An optional header row whose first column is 'Anchor' (case-insensitive)
    is skipped; otherwise the file is rewound and the first row is treated
    as data. Only rows with exactly three columns are accepted.

    Returns a human-readable status string (this method never raises to
    the UI).
    """
    if not file_path:
        return "Please upload a CSV file."

    new_dataset: List[List[str]] = []
    skipped = 0
    try:
        with open(file_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            try:
                header = next(reader)
                # Rewind when the first row is data rather than a header.
                if not (header and header[0].lower().strip() == 'anchor'):
                    f.seek(0)
            except StopIteration:
                return "Error: Uploaded file is empty."

            for row in reader:
                if len(row) == 3:
                    new_dataset.append([s.strip() for s in row])
                else:
                    # BUG FIX: previously malformed rows were dropped
                    # silently; count them so the user is told.
                    skipped += 1
    except Exception as e:
        return f"Import failed: {e}"

    if not new_dataset:
        # BUG FIX: the original raised ValueError only to catch it in its
        # own broad except handler; return the message directly instead of
        # using exceptions for control flow.
        return "Import failed: No valid rows found."

    self.imported_dataset = new_dataset
    result = f"Imported {len(new_dataset)} triplets."
    if skipped:
        result += f" Skipped {skipped} malformed row(s)."
    return result
|
| 141 |
+
|
| 142 |
+
def export_dataset(self) -> Optional[str]:
    """
    Write the last generated HN triplet dataset to the session CSV file.

    Returns the file path (str) for a gr.File component, or None when
    there is nothing to export or the write fails.
    """
    if not self.last_hn_dataset:
        gr.Warning("No dataset generated yet.")
        return None

    file_path = self.dataset_export_file
    try:
        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Anchor', 'Positive', 'Negative'])
            writer.writerows(self.last_hn_dataset)
        gr.Info("Dataset exported.")
        return str(file_path)
    except Exception as e:
        # BUG FIX: gr.Error is an Exception subclass -- instantiating it
        # without raising displays nothing in the UI. Use gr.Warning so the
        # failure is visible while still returning None to keep the chain alive.
        gr.Warning(f"Export failed: {e}")
        return None
|
| 158 |
+
|
| 159 |
+
def download_model(self) -> Optional[str]:
    """
    Zip the session's fine-tuned model directory for download.

    Returns the archive path for a gr.File component, or None when no
    model has been saved yet or archiving fails.
    """
    # BUG FIX: __init__ always creates output_dir, so a bare exists()
    # check could never report "not trained". Require the directory to be
    # non-empty as well.
    if not os.path.exists(self.output_dir) or not os.listdir(self.output_dir):
        gr.Warning("No model trained yet.")
        return None

    # Timestamp keeps successive downloads from overwriting each other.
    timestamp = int(time.time())
    try:
        base_name = self.session_root / f"model_finetuned_{timestamp}"
        archive_path = shutil.make_archive(
            base_name=str(base_name),
            format='zip',
            root_dir=self.output_dir,
        )
        gr.Info("Model zipped.")
        return archive_path
    except Exception as e:
        # BUG FIX: gr.Error(...) without raise shows nothing; warn instead.
        gr.Warning(f"Zip failed: {e}")
        return None
|
| 177 |
+
|
| 178 |
+
def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
    """
    Push the session's fine-tuned model directory to the Hugging Face Hub.

    Returns a human-readable status string from the uploader (or a local
    validation error message).
    """
    # BUG FIX: output_dir is pre-created in __init__, so exists() alone
    # can never detect "never trained"; also require it to be non-empty.
    if not os.path.exists(self.output_dir) or not os.listdir(self.output_dir):
        return "β Error: No trained model found in this session. Run training first."
    if not repo_name.strip():
        return "β Error: Please specify a repository name."

    return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
## Training Logic ##
|
| 191 |
+
def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
|
| 192 |
+
total_ids, selected_ids = set(self.number_list), set(selected_ids)
|
| 193 |
+
non_selected_ids = total_ids - selected_ids
|
| 194 |
+
is_minority = len(selected_ids) < (len(total_ids) / 2)
|
| 195 |
+
|
| 196 |
+
anchor_ids, pool_ids = (non_selected_ids, list(selected_ids)) if is_minority else (selected_ids, list(non_selected_ids))
|
| 197 |
+
|
| 198 |
+
def get_titles(anchor_id, pool_id):
|
| 199 |
+
return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
|
| 200 |
+
|
| 201 |
+
if not pool_ids or not anchor_ids:
|
| 202 |
+
return [], "", ""
|
| 203 |
+
|
| 204 |
+
fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
|
| 205 |
+
non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
|
| 206 |
+
|
| 207 |
+
hn_dataset = []
|
| 208 |
+
pool_cycler = cycle(pool_ids)
|
| 209 |
+
for anchor_id in sorted(list(anchor_ids)):
|
| 210 |
+
fav, non_fav = get_titles(anchor_id, next(pool_cycler))
|
| 211 |
+
hn_dataset.append([self.config.QUERY_ANCHOR, fav, non_fav])
|
| 212 |
+
|
| 213 |
+
return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
|
| 214 |
+
|
| 215 |
+
def training(self, selected_ids: List[int]) -> str:
    """
    Fine-tune the session model on the user's selections.

    Validates the input, assembles the triplet dataset (HN picks plus any
    imported CSV rows), runs a semantic search before and after training,
    and returns both result blocks as one markdown string.

    Raises:
        gr.Error: for user-facing validation failures.
    """
    # --- Validation (surfaced as Gradio error modals) ---
    if self.model is None:
        raise gr.Error("Model not loaded.")
    if not selected_ids:
        raise gr.Error("Select at least one title.")
    if len(selected_ids) == len(self.number_list):
        raise gr.Error("Cannot select all titles.")

    # --- Dataset assembly ---
    hn_triplets, _, _ = self._create_hn_dataset(selected_ids)
    self.last_hn_dataset = hn_triplets
    final_dataset = hn_triplets + self.imported_dataset
    if not final_dataset:
        raise gr.Error("Dataset is empty.")

    def semantic_search_fn() -> str:
        # Rank the held-out titles against the anchor with the model's
        # current weights.
        return get_top_hits(
            model=self.model,
            target_titles=self.target_titles,
            task_name=self.config.TASK_NAME,
            query=self.config.QUERY_ANCHOR,
        )

    # Snapshot search results before training for the before/after diff.
    before = semantic_search_fn()
    print(f"[{self.session_id}] Starting Training...")

    # NOTE(review): train_with_dataset appears to update self.model and
    # persist it under output_dir -- confirm against src.model_trainer.
    train_with_dataset(
        model=self.model,
        dataset=final_dataset,
        output_dir=self.output_dir,
        task_name=self.config.TASK_NAME,
        search_fn=semantic_search_fn
    )

    self._update_vibe_checker()
    print(f"[{self.session_id}] Training Complete.")

    after = semantic_search_fn()
    return f"### Search (Before):\n{before}\n\n### Search (After):\n{after}"
|
| 249 |
+
|
| 250 |
+
## Vibe Check Logic ##
|
| 251 |
+
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update, str]:
    """
    Score a piece of text against the anchor and describe the result.

    Returns (score string, status label, lamp-HTML gr.update, session info
    markdown). BUG FIX: the annotation previously declared a 3-tuple while
    every branch returns four values.
    """
    info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"

    if not self.vibe_checker:
        return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_html("gray")), info_text
    # Reject very short inputs: similarity scores on 1-2 words are noise.
    if not news_text or len(news_text.split()) < 3:
        return "N/A", "Text too short", gr.update(value=self._generate_vibe_html("white")), info_text

    try:
        vibe_result = self.vibe_checker.check(news_text)
        # Extract the plain label from e.g. '<span ...>LABEL</span>'.
        # NOTE(review): assumes status_html wraps the label in exactly one
        # tag pair -- confirm against vibe_logic.
        status = vibe_result.status_html.split('>')[1].split('<')[0]
        return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl)), info_text
    except Exception as e:
        return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_html("gray")), info_text
|
| 265 |
+
|
| 266 |
+
def _generate_vibe_html(self, color: str) -> str:
|
| 267 |
+
return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
|
| 268 |
+
|
| 269 |
+
## Mood Reader Logic ##
|
| 270 |
+
def fetch_and_display_mood_feed(self) -> str:
    """
    Fetch the live HN feed, vibe-score every story, and render a markdown
    table sorted by score (highest first).
    """
    if not self.vibe_checker:
        return "Model not ready. Please wait or reload."

    feed, status = read_hacker_news_rss(self.config)
    if not feed or not feed.entries:
        return f"**Feed Error:** {status}"

    scored_entries = []
    for entry in feed.entries:
        title = entry.get('title')
        if not title:
            continue
        scored_entries.append({
            "title": title,
            "link": entry.get('link', '#'),
            "comments": entry.get('comments', '#'),
            "published": format_published_time(entry.published_parsed),
            "mood": self.vibe_checker.check(title),
        })

    # Highest vibe score first.
    scored_entries.sort(key=lambda item: item["mood"].raw_score, reverse=True)

    header = (
        f"## Hacker News Top Stories\n"
        f"**Session:** {self.session_id[:6]}<br>"
        f"**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}<br>"
        f"**Updated:** {datetime.now().strftime('%H:%M:%S')}\n\n"
        "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n"
    )
    rows = [
        f"| {item['mood'].status_html} "
        f"| {item['mood'].raw_score:.4f} "
        f"| [{item['title']}]({item['link']}) "
        f"| [Comments]({item['comments']}) "
        f"| {item['published']} |\n"
        for item in scored_entries
    ]
    return header + "".join(rows)
|
src/ui.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from typing import Optional
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
|
| 5 |
+
from .config import AppConfig
|
| 6 |
+
from .session_manager import HackerNewsFineTuner
|
| 7 |
+
|
| 8 |
+
# --- Session Wrappers ---
|
| 9 |
+
|
| 10 |
+
def refresh_wrapper(app):
    """
    Lazily create the per-user session, then refresh its data and model.

    Gradio may hand us None (fresh state) or a stale class/callable value;
    in those cases a new HackerNewsFineTuner is built.

    Returns (session instance for gr.State, checkbox update, status update).
    """
    needs_init = app is None or callable(app) or isinstance(app, type)
    if needs_init:
        print("Initializing new HackerNewsFineTuner session...")
        app = HackerNewsFineTuner(AppConfig)

    choices_update, status_update = app.refresh_data_and_model()
    return app, choices_update, status_update
|
| 24 |
+
|
| 25 |
+
def import_wrapper(app, file):
    """Delegate CSV dataset import to the session instance."""
    return app.import_additional_dataset(file)


def export_wrapper(app):
    """Delegate dataset export to the session instance."""
    return app.export_dataset()


def download_model_wrapper(app):
    """Delegate model zipping to the session instance."""
    return app.download_model()
|
| 33 |
+
|
| 34 |
+
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
    """
    Push the session's fine-tuned model to the Hugging Face Hub.

    Gradio injects 'oauth_token' automatically when the user is signed in
    via LoginButton; it is None for anonymous visitors.
    """
    if oauth_token is None:
        return "β οΈ You must be logged in to push to the Hub. Please sign in above."

    # Hand the raw token string to the session-level uploader.
    return app.upload_model(repo_name, oauth_token.token)
|
| 45 |
+
|
| 46 |
+
def training_wrapper(app, selected_ids):
    """Delegate fine-tuning to the session instance."""
    return app.training(selected_ids)


def vibe_check_wrapper(app, text):
    """Delegate a single-text vibe check to the session instance."""
    return app.get_vibe_check(text)


def mood_feed_wrapper(app):
    """Delegate mood-feed rendering to the session instance."""
    return app.fetch_and_display_mood_feed()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# --- Interface Setup ---
|
| 57 |
+
|
| 58 |
+
def build_interface() -> gr.Blocks:
    """
    Construct the full Gradio Blocks UI and wire up all event handlers.

    Per-user state lives in a gr.State holding a HackerNewsFineTuner
    instance, created lazily by refresh_wrapper on first load.
    """
    with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
        # Initialize state as None. It will be populated by refresh_wrapper on load.
        session_state = gr.State()

        with gr.Column():
            gr.Markdown("# π€ EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
            gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
            gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")

        with gr.Tab("π Fine-Tuning & Evaluation"):
            with gr.Column():
                gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
                with gr.Row():
                    favorite_list = gr.CheckboxGroup(choices=[], type="index", label="Hacker News Top Stories", show_select_all=True)
                    output = gr.Textbox(lines=14, label="Training and Search Results", value="Loading data...")

                with gr.Row():
                    clear_reload_btn = gr.Button("Clear & Reload")
                    run_training_btn = gr.Button("π Run Fine-Tuning", variant="primary")

                gr.Markdown("--- \n ## Dataset & Model Management")
                gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
                import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)

                with gr.Row():
                    download_dataset_btn = gr.Button("πΎ Export Dataset")
                    download_model_btn = gr.Button("β¬οΈ Download Fine-Tuned Model")

                download_status = gr.Markdown("Ready.")

                with gr.Row():
                    # Hidden until a file is actually produced.
                    dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
                    model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)

                gr.Markdown("### βοΈ Publish to Hugging Face Hub")
                with gr.Row():
                    repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g., my-news-vibe-model")
                    push_to_hub_btn = gr.Button("Push to Hub", variant="secondary")

                push_status = gr.Markdown("")

        # --- Interactions ---

        # 1. Initial Load: Initialize State and Load Data
        demo.load(
            fn=refresh_wrapper,
            inputs=[session_state],
            outputs=[session_state, favorite_list, output]
        )

        # Buttons disabled while any long-running action is in flight.
        buttons_to_lock = [
            clear_reload_btn,
            run_training_btn,
            download_dataset_btn,
            download_model_btn,
            push_to_hub_btn
        ]

        # 2. Buttons
        # Pattern for long actions: lock buttons -> do work -> unlock.
        clear_reload_btn.click(
            fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
            outputs=buttons_to_lock
        ).then(
            fn=refresh_wrapper,
            inputs=[session_state],
            outputs=[session_state, favorite_list, output]
        ).then(
            fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
            outputs=buttons_to_lock
        )

        run_training_btn.click(
            fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
            outputs=buttons_to_lock
        ).then(
            fn=training_wrapper,
            inputs=[session_state, favorite_list],
            outputs=[output]
        ).then(
            fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
            outputs=buttons_to_lock
        )

        import_file.change(
            fn=import_wrapper,
            inputs=[session_state, import_file],
            outputs=[download_status]
        )

        download_dataset_btn.click(
            fn=export_wrapper,
            inputs=[session_state],
            outputs=[dataset_output]
        ).then(
            # Reveal the file component only when export produced a path.
            lambda p: gr.update(visible=True) if p else gr.update(), inputs=[dataset_output], outputs=[dataset_output]
        )

        download_model_btn.click(
            fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
            outputs=buttons_to_lock
        ).then(
            # Clear any stale archive and show progress text (unqueued so
            # the UI updates immediately).
            lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
        ).then(
            fn=download_model_wrapper,
            inputs=[session_state],
            outputs=[model_output]
        ).then(
            lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
        ).then(
            fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
            outputs=buttons_to_lock
        )

        # Push to Hub Interaction
        push_to_hub_btn.click(
            fn=push_to_hub_wrapper,
            inputs=[session_state, repo_name_input],
            outputs=[push_status]
        )

        with gr.Tab("π° Hacker News Similarity Check"):
            with gr.Column():
                gr.Markdown(f"## Live Hacker News Feed Vibe")
                gr.Markdown(f"This feed uses the current model (base or fine-tuned) to score the vibe of live Hacker News stories against **`{AppConfig.QUERY_ANCHOR}`**.")
                feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
                refresh_button = gr.Button("Refresh Feed π", size="lg", variant="primary")
                refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)

        with gr.Tab("π‘ Similarity Lamp"):
            with gr.Column():
                gr.Markdown(f"## News Similarity Check")
                gr.Markdown(f"Enter text to see its similarity to **`{AppConfig.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
                news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
                vibe_check_btn = gr.Button("Check Similarity", variant="primary")

                gr.Examples(
                    examples=[
                        "Global Markets Rally as Inflation Data Shows Unexpected Drop for Third Consecutive Month",
                        "Astronomers Detect Strong Oxygen Signature on Potentially Habitable Exoplanet",
                        "City Council Approves Controversial Plan to Ban Cars from Downtown District by 2027",
                        "Tech Giant Unveils Prototype for \"Invisible\" AR Glasses, Promising a Screen-Free Future",
                        "Local Library Receives Overdue Book Checked Out in 1948 With An Anonymous Apology Note"
                    ],
                    inputs=news_input,
                    label="Try these examples"
                )

                session_info_display = gr.Markdown()

                with gr.Row():
                    vibe_color_block = gr.HTML(value='<div style="background-color: gray; height: 100px;"></div>', label="Mood Lamp")
                    with gr.Column():
                        vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
                        vibe_status = gr.Textbox(label="Status", value="...", interactive=False)

                vibe_check_btn.click(
                    fn=vibe_check_wrapper,
                    inputs=[session_state, news_input],
                    outputs=[vibe_score, vibe_status, vibe_color_block, session_info_display]
                )

    return demo
|
src/vibe_logic.py
CHANGED
|
@@ -23,7 +23,7 @@ VIBE_THRESHOLDS: List[VibeThreshold] = [
|
|
| 23 |
VibeThreshold(score=0.8, status="β¨ VIBE:HIGH"),
|
| 24 |
VibeThreshold(score=0.5, status="π VIBE:GOOD"),
|
| 25 |
VibeThreshold(score=0.2, status="π VIBE:FLAT"),
|
| 26 |
-
VibeThreshold(score=0.0, status="π VIBE:LOW
|
| 27 |
]
|
| 28 |
|
| 29 |
# --- Utility Functions ---
|
|
|
|
| 23 |
VibeThreshold(score=0.8, status="β¨ VIBE:HIGH"),
|
| 24 |
VibeThreshold(score=0.5, status="π VIBE:GOOD"),
|
| 25 |
VibeThreshold(score=0.2, status="π VIBE:FLAT"),
|
| 26 |
+
VibeThreshold(score=0.0, status="π VIBE:LOW"), # Base case for scores < 0.2
|
| 27 |
]
|
| 28 |
|
| 29 |
# --- Utility Functions ---
|