Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import os
|
|
| 3 |
import shutil
|
| 4 |
import time
|
| 5 |
import csv
|
|
|
|
| 6 |
from itertools import cycle
|
| 7 |
from typing import List, Iterable, Tuple, Optional, Callable
|
| 8 |
from datetime import datetime
|
|
@@ -19,46 +20,47 @@ from config import AppConfig
|
|
| 19 |
from vibe_logic import VibeChecker
|
| 20 |
from sentence_transformers import SentenceTransformer
|
| 21 |
|
| 22 |
-
# --- Main Application Class ---
|
| 23 |
|
| 24 |
class HackerNewsFineTuner:
|
| 25 |
"""
|
| 26 |
-
Encapsulates all application logic and state for
|
| 27 |
-
Manages the embedding model, news data, and training datasets.
|
| 28 |
"""
|
| 29 |
|
| 30 |
def __init__(self, config: AppConfig = AppConfig):
|
| 31 |
# --- Dependencies ---
|
| 32 |
self.config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# --- Application State ---
|
| 35 |
self.model: Optional[SentenceTransformer] = None
|
| 36 |
self.vibe_checker: Optional[VibeChecker] = None
|
| 37 |
-
self.titles: List[str] = []
|
| 38 |
-
self.target_titles: List[str] = []
|
| 39 |
-
self.number_list: List[int] = []
|
| 40 |
-
self.last_hn_dataset: List[List[str]] = []
|
| 41 |
-
self.imported_dataset: List[List[str]] = []
|
| 42 |
-
|
| 43 |
-
# Setup
|
| 44 |
-
os.makedirs(self.config.ARTIFACTS_DIR, exist_ok=True)
|
| 45 |
-
print(f"Created artifact directory: {self.config.ARTIFACTS_DIR}")
|
| 46 |
-
|
| 47 |
-
authenticate_hf(self.config.HF_TOKEN)
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
self.
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
print("--- Running Initial Data Load ---")
|
| 55 |
-
self.refresh_data_and_model()
|
| 56 |
-
print("--- Initial Load Complete ---")
|
| 57 |
|
| 58 |
def _update_vibe_checker(self):
|
| 59 |
"""Initializes or updates the VibeChecker with the current model state."""
|
| 60 |
if self.model:
|
| 61 |
-
print("Updating VibeChecker instance with the current model.")
|
| 62 |
self.vibe_checker = VibeChecker(
|
| 63 |
model=self.model,
|
| 64 |
query_anchor=self.config.QUERY_ANCHOR,
|
|
@@ -71,14 +73,10 @@ class HackerNewsFineTuner:
|
|
| 71 |
|
| 72 |
def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
|
| 73 |
"""
|
| 74 |
-
|
| 75 |
-
2. Fetches fresh news data (from cache or web).
|
| 76 |
-
3. Updates the class state and returns Gradio updates for the UI.
|
| 77 |
"""
|
| 78 |
-
print("
|
| 79 |
-
print("RELOADING MODEL and RE-FETCHING DATA")
|
| 80 |
|
| 81 |
-
# Reset dataset state
|
| 82 |
self.last_hn_dataset = []
|
| 83 |
self.imported_dataset = []
|
| 84 |
|
|
@@ -87,33 +85,32 @@ class HackerNewsFineTuner:
|
|
| 87 |
self.model = load_embedding_model(self.config.MODEL_NAME)
|
| 88 |
self._update_vibe_checker()
|
| 89 |
except Exception as e:
|
| 90 |
-
|
|
|
|
| 91 |
self.model = None
|
| 92 |
self._update_vibe_checker()
|
| 93 |
return (
|
| 94 |
gr.update(choices=[], label="Model Load Failed"),
|
| 95 |
-
gr.update(value=
|
| 96 |
)
|
| 97 |
|
| 98 |
# 2. Fetch fresh news data
|
|
|
|
| 99 |
news_feed, status_msg = read_hacker_news_rss(self.config)
|
| 100 |
titles_out, target_titles_out = [], []
|
| 101 |
-
status_value: str = f"
|
| 102 |
|
| 103 |
if news_feed is not None and news_feed.entries:
|
| 104 |
-
# Use constant for clarity
|
| 105 |
titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
|
| 106 |
target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
|
| 107 |
-
print(f"Data reloaded: {len(titles_out)} selection titles, {len(target_titles_out)} target titles.")
|
| 108 |
else:
|
| 109 |
-
titles_out = ["Error fetching news
|
| 110 |
-
gr.Warning(f"Data reload failed.
|
| 111 |
|
| 112 |
self.titles = titles_out
|
| 113 |
self.target_titles = target_titles_out
|
| 114 |
self.number_list = list(range(len(self.titles)))
|
| 115 |
|
| 116 |
-
# Return Gradio updates for CheckboxGroup and Textbox
|
| 117 |
return (
|
| 118 |
gr.update(
|
| 119 |
choices=self.titles,
|
|
@@ -142,55 +139,52 @@ class HackerNewsFineTuner:
|
|
| 142 |
new_dataset.append([s.strip() for s in row])
|
| 143 |
num_imported += 1
|
| 144 |
if num_imported == 0:
|
| 145 |
-
raise ValueError("No valid
|
| 146 |
self.imported_dataset = new_dataset
|
| 147 |
-
return f"
|
| 148 |
except Exception as e:
|
| 149 |
-
|
| 150 |
-
return "Import failed. Check console for details."
|
| 151 |
|
| 152 |
def export_dataset(self) -> Optional[str]:
|
| 153 |
if not self.last_hn_dataset:
|
| 154 |
-
gr.Warning("No dataset
|
| 155 |
return None
|
| 156 |
-
|
|
|
|
|
|
|
| 157 |
try:
|
| 158 |
-
print(f"Exporting dataset to {file_path}...")
|
| 159 |
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
| 160 |
writer = csv.writer(f)
|
| 161 |
writer.writerow(['Anchor', 'Positive', 'Negative'])
|
| 162 |
writer.writerows(self.last_hn_dataset)
|
| 163 |
-
gr.Info(f"Dataset
|
| 164 |
return str(file_path)
|
| 165 |
except Exception as e:
|
| 166 |
-
gr.Error(f"
|
| 167 |
return None
|
| 168 |
|
| 169 |
def download_model(self) -> Optional[str]:
|
| 170 |
-
if not os.path.exists(self.
|
| 171 |
-
gr.Warning(
|
| 172 |
return None
|
|
|
|
| 173 |
timestamp = int(time.time())
|
| 174 |
try:
|
| 175 |
-
|
|
|
|
| 176 |
archive_path = shutil.make_archive(
|
| 177 |
-
base_name=base_name,
|
| 178 |
format='zip',
|
| 179 |
-
root_dir=self.
|
| 180 |
)
|
| 181 |
-
gr.Info(f"Model
|
| 182 |
return archive_path
|
| 183 |
except Exception as e:
|
| 184 |
-
gr.Error(f"
|
| 185 |
return None
|
| 186 |
|
| 187 |
## Training Logic ##
|
| 188 |
def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
|
| 189 |
-
"""
|
| 190 |
-
Internal function to generate the [Anchor, Positive, Negative] triplets
|
| 191 |
-
from the user's Hacker News title selection.
|
| 192 |
-
Returns (dataset, favorite_title, non_favorite_title)
|
| 193 |
-
"""
|
| 194 |
total_ids, selected_ids = set(self.number_list), set(selected_ids)
|
| 195 |
non_selected_ids = total_ids - selected_ids
|
| 196 |
is_minority = len(selected_ids) < (len(total_ids) / 2)
|
|
@@ -200,6 +194,9 @@ class HackerNewsFineTuner:
|
|
| 200 |
def get_titles(anchor_id, pool_id):
|
| 201 |
return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
|
| 202 |
|
|
|
|
|
|
|
|
|
|
| 203 |
fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
|
| 204 |
non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
|
| 205 |
|
|
@@ -212,63 +209,66 @@ class HackerNewsFineTuner:
|
|
| 212 |
return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
|
| 213 |
|
| 214 |
def training(self, selected_ids: List[int]) -> str:
|
| 215 |
-
"""
|
| 216 |
-
Generates a training dataset from user selection and runs the fine-tuning process.
|
| 217 |
-
"""
|
| 218 |
if self.model is None:
|
| 219 |
-
raise gr.Error("
|
| 220 |
if not selected_ids:
|
| 221 |
-
raise gr.Error("
|
| 222 |
if len(selected_ids) == len(self.number_list):
|
| 223 |
-
raise gr.Error("
|
| 224 |
|
| 225 |
-
hn_dataset,
|
| 226 |
self.last_hn_dataset = hn_dataset
|
| 227 |
final_dataset = self.last_hn_dataset + self.imported_dataset
|
|
|
|
| 228 |
if not final_dataset:
|
| 229 |
-
raise gr.Error("
|
| 230 |
-
print(f"Combined dataset size: {len(final_dataset)} triplets.")
|
| 231 |
|
| 232 |
def semantic_search_fn() -> str:
|
| 233 |
return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
|
| 234 |
|
| 235 |
-
result = "###
|
| 236 |
-
print("
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
self._update_vibe_checker()
|
| 239 |
-
print("
|
| 240 |
|
| 241 |
-
result += "###
|
| 242 |
return result
|
| 243 |
|
| 244 |
-
## Vibe Check Logic
|
| 245 |
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
|
| 246 |
if not self.vibe_checker:
|
| 247 |
-
|
| 248 |
-
return "N/A", "Model Error", gr.update(value=self._generate_vibe_html("gray"))
|
| 249 |
if not news_text or len(news_text.split()) < 3:
|
| 250 |
-
|
| 251 |
-
return "N/A", "Please enter text", gr.update(value=self._generate_vibe_html("white"))
|
| 252 |
|
| 253 |
try:
|
| 254 |
vibe_result = self.vibe_checker.check(news_text)
|
| 255 |
-
status = vibe_result.status_html.split('>')[1].split('<')[0]
|
| 256 |
return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl))
|
| 257 |
except Exception as e:
|
| 258 |
-
|
| 259 |
-
return "N/A", f"Processing Error: {e}", gr.update(value=self._generate_vibe_html("gray"))
|
| 260 |
|
| 261 |
def _generate_vibe_html(self, color: str) -> str:
|
| 262 |
return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
|
| 263 |
|
| 264 |
-
## Mood Reader Logic
|
| 265 |
def fetch_and_display_mood_feed(self) -> str:
|
| 266 |
if not self.vibe_checker:
|
| 267 |
-
return "
|
| 268 |
|
| 269 |
feed, status = read_hacker_news_rss(self.config)
|
| 270 |
if not feed or not feed.entries:
|
| 271 |
-
return f"**
|
| 272 |
|
| 273 |
scored_entries = []
|
| 274 |
for entry in feed.entries:
|
|
@@ -286,8 +286,8 @@ class HackerNewsFineTuner:
|
|
| 286 |
|
| 287 |
scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
|
| 288 |
|
| 289 |
-
md = (f"## Hacker News
|
| 290 |
-
f"**
|
| 291 |
"| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
|
| 292 |
|
| 293 |
for item in scored_entries:
|
|
@@ -297,94 +297,140 @@ class HackerNewsFineTuner:
|
|
| 297 |
f"| [Comments]({item['comments']}) "
|
| 298 |
f"| {item['published']} |\n")
|
| 299 |
return md
|
| 300 |
-
# 🤖 Embedding Gemma Modkit: Fine-Tuning and Mood Reader
|
| 301 |
-
|
| 302 |
-
## Gradio Interface Setup ##
|
| 303 |
-
def build_interface(self) -> gr.Blocks:
|
| 304 |
-
with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
|
| 305 |
-
gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
|
| 306 |
-
gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
|
| 307 |
-
with gr.Tab("🚀 Fine-Tuning & Evaluation"):
|
| 308 |
-
self._build_training_interface()
|
| 309 |
-
with gr.Tab("📰 Hacker News Mood Reader"):
|
| 310 |
-
self._build_mood_reader_interface()
|
| 311 |
-
with gr.Tab("💡 Similarity Check"):
|
| 312 |
-
self._build_vibe_check_interface()
|
| 313 |
-
return demo
|
| 314 |
-
|
| 315 |
-
def _build_training_interface(self):
|
| 316 |
-
with gr.Column():
|
| 317 |
-
gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
|
| 318 |
-
with gr.Row():
|
| 319 |
-
favorite_list = gr.CheckboxGroup(self.titles, type="index", label=f"Hacker News Top {len(self.titles)}", show_select_all=True)
|
| 320 |
-
output = gr.Textbox(lines=14, label="Training and Search Results", value="Click 'Run Fine-Tuning' to begin.")
|
| 321 |
-
with gr.Row():
|
| 322 |
-
clear_reload_btn = gr.Button("Clear & Reload Model/Data")
|
| 323 |
-
run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
|
| 324 |
-
gr.Markdown("--- \n ## Dataset & Model Management")
|
| 325 |
-
gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
|
| 326 |
-
import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
|
| 327 |
-
with gr.Row():
|
| 328 |
-
download_dataset_btn = gr.Button("💾 Export Last HN Dataset")
|
| 329 |
-
download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
|
| 330 |
-
download_status = gr.Markdown("Ready.")
|
| 331 |
-
with gr.Row():
|
| 332 |
-
dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
|
| 333 |
-
model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
|
| 334 |
-
|
| 335 |
-
buttons_to_lock = [
|
| 336 |
-
clear_reload_btn,
|
| 337 |
-
run_training_btn,
|
| 338 |
-
download_dataset_btn,
|
| 339 |
-
download_model_btn
|
| 340 |
-
]
|
| 341 |
-
|
| 342 |
-
run_training_btn.click(
|
| 343 |
-
fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
|
| 344 |
-
outputs=buttons_to_lock
|
| 345 |
-
).then(
|
| 346 |
-
fn=self.training, inputs=favorite_list, outputs=output
|
| 347 |
-
).then(
|
| 348 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 349 |
-
outputs=buttons_to_lock
|
| 350 |
-
)
|
| 351 |
-
clear_reload_btn.click(fn=self.refresh_data_and_model, inputs=None, outputs=[favorite_list, output], queue=False)
|
| 352 |
-
import_file.change(fn=self.import_additional_dataset, inputs=[import_file], outputs=download_status)
|
| 353 |
-
download_dataset_btn.click(lambda: [gr.update(value=None, visible=False), "Generating..."], None, [dataset_output, download_status], queue=False).then(self.export_dataset, None, dataset_output).then(lambda p: [gr.update(visible=p is not None, value=p), "CSV ready." if p else "Export failed."], [dataset_output], [dataset_output, download_status])
|
| 354 |
-
download_model_btn.click(
|
| 355 |
-
fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
|
| 356 |
-
outputs=buttons_to_lock
|
| 357 |
-
).then(
|
| 358 |
-
lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
|
| 359 |
-
).then(self.download_model, None, model_output).then(lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
|
| 360 |
-
).then(
|
| 361 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 362 |
-
outputs=buttons_to_lock
|
| 363 |
-
)
|
| 364 |
|
| 365 |
-
def _build_vibe_check_interface(self):
|
| 366 |
-
with gr.Column():
|
| 367 |
-
gr.Markdown(f"## News Vibe Check Mood Lamp\nEnter text to see its similarity to **`{self.config.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
|
| 368 |
-
news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
|
| 369 |
-
vibe_check_btn = gr.Button("Check Vibe", variant="primary")
|
| 370 |
-
with gr.Row():
|
| 371 |
-
vibe_color_block = gr.HTML(value=self._generate_vibe_html("white"), label="Mood Lamp")
|
| 372 |
-
with gr.Column():
|
| 373 |
-
vibe_score = gr.Textbox(label="Cosine Similarity Score", value="N/A", interactive=False)
|
| 374 |
-
vibe_status = gr.Textbox(label="Vibe Status", value="Enter text and click 'Check Vibe'", interactive=False, lines=2)
|
| 375 |
-
vibe_check_btn.click(fn=self.get_vibe_check, inputs=[news_input], outputs=[vibe_score, vibe_status, vibe_color_block])
|
| 376 |
-
|
| 377 |
-
def _build_mood_reader_interface(self):
|
| 378 |
-
with gr.Column():
|
| 379 |
-
gr.Markdown(f"## Live Hacker News Feed Vibe\nThis feed uses the current model (base or fine-tuned) to score the vibe of live HN stories against **`{self.config.QUERY_ANCHOR}`**.")
|
| 380 |
-
feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
|
| 381 |
-
refresh_button = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
|
| 382 |
-
refresh_button.click(fn=self.fetch_and_display_mood_feed, inputs=None, outputs=feed_output)
|
| 383 |
|
|
|
|
|
|
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import shutil
|
| 4 |
import time
|
| 5 |
import csv
|
| 6 |
+
import uuid
|
| 7 |
from itertools import cycle
|
| 8 |
from typing import List, Iterable, Tuple, Optional, Callable
|
| 9 |
from datetime import datetime
|
|
|
|
| 20 |
from vibe_logic import VibeChecker
|
| 21 |
from sentence_transformers import SentenceTransformer
|
| 22 |
|
| 23 |
+
# --- Main Application Class (Session Scoped) ---
|
| 24 |
|
| 25 |
class HackerNewsFineTuner:
|
| 26 |
"""
|
| 27 |
+
Encapsulates all application logic and state for a single user session.
|
|
|
|
| 28 |
"""
|
| 29 |
|
| 30 |
def __init__(self, config: AppConfig = AppConfig):
    """Prepare the state for one user session.

    Generates a random session id, derives the session's artifact paths
    from ``config.ARTIFACTS_DIR``, creates the model output directory,
    initialises empty model/data state, and authenticates with Hugging Face
    using ``config.HF_TOKEN``. No model or news data is loaded here.

    Args:
        config: Configuration object; defaults to the ``AppConfig`` class.
    """
    self.config = config

    # Random per-session id: names the artifact folder and prefixes log lines.
    self.session_id = str(uuid.uuid4())

    # Everything this session writes lives under <ARTIFACTS_DIR>/<session_id>/.
    session_root = self.config.ARTIFACTS_DIR / self.session_id
    self.session_root = session_root
    self.output_dir = session_root / "embedding_gemma_finetuned"
    self.dataset_export_file = session_root / "training_dataset.csv"

    # Creating the output directory also creates the session root above it.
    os.makedirs(self.output_dir, exist_ok=True)
    print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")

    # Model-related state; stays None until a model is loaded.
    self.model: Optional[SentenceTransformer] = None
    self.vibe_checker: Optional[VibeChecker] = None

    # News titles and training triplets; all start empty.
    self.titles: List[str] = []
    self.target_titles: List[str] = []
    self.number_list: List[int] = []
    self.last_hn_dataset: List[List[str]] = []
    self.imported_dataset: List[List[str]] = []

    # NOTE(review): this runs for every new session object — confirm
    # authenticate_hf is cheap/idempotent when called repeatedly.
    authenticate_hf(self.config.HF_TOKEN)

    # Data loading is intentionally left out of the constructor to keep it fast.
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def _update_vibe_checker(self):
|
| 62 |
"""Initializes or updates the VibeChecker with the current model state."""
|
| 63 |
if self.model:
|
|
|
|
| 64 |
self.vibe_checker = VibeChecker(
|
| 65 |
model=self.model,
|
| 66 |
query_anchor=self.config.QUERY_ANCHOR,
|
|
|
|
| 73 |
|
| 74 |
def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
|
| 75 |
"""
|
| 76 |
+
Reloads model and fetches data.
|
|
|
|
|
|
|
| 77 |
"""
|
| 78 |
+
print(f"[{self.session_id}] Reloading model and data...")
|
|
|
|
| 79 |
|
|
|
|
| 80 |
self.last_hn_dataset = []
|
| 81 |
self.imported_dataset = []
|
| 82 |
|
|
|
|
| 85 |
self.model = load_embedding_model(self.config.MODEL_NAME)
|
| 86 |
self._update_vibe_checker()
|
| 87 |
except Exception as e:
|
| 88 |
+
error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
|
| 89 |
+
print(error_msg)
|
| 90 |
self.model = None
|
| 91 |
self._update_vibe_checker()
|
| 92 |
return (
|
| 93 |
gr.update(choices=[], label="Model Load Failed"),
|
| 94 |
+
gr.update(value=error_msg)
|
| 95 |
)
|
| 96 |
|
| 97 |
# 2. Fetch fresh news data
|
| 98 |
+
# Note: Cache file is shared (global), which is fine/desired for RSS data.
|
| 99 |
news_feed, status_msg = read_hacker_news_rss(self.config)
|
| 100 |
titles_out, target_titles_out = [], []
|
| 101 |
+
status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
|
| 102 |
|
| 103 |
if news_feed is not None and news_feed.entries:
|
|
|
|
| 104 |
titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
|
| 105 |
target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
|
|
|
|
| 106 |
else:
|
| 107 |
+
titles_out = ["Error fetching news.", "Check console."]
|
| 108 |
+
gr.Warning(f"Data reload failed. {status_msg}")
|
| 109 |
|
| 110 |
self.titles = titles_out
|
| 111 |
self.target_titles = target_titles_out
|
| 112 |
self.number_list = list(range(len(self.titles)))
|
| 113 |
|
|
|
|
| 114 |
return (
|
| 115 |
gr.update(
|
| 116 |
choices=self.titles,
|
|
|
|
| 139 |
new_dataset.append([s.strip() for s in row])
|
| 140 |
num_imported += 1
|
| 141 |
if num_imported == 0:
|
| 142 |
+
raise ValueError("No valid rows found.")
|
| 143 |
self.imported_dataset = new_dataset
|
| 144 |
+
return f"Imported {num_imported} triplets."
|
| 145 |
except Exception as e:
|
| 146 |
+
return f"Import failed: {e}"
|
|
|
|
| 147 |
|
| 148 |
def export_dataset(self) -> Optional[str]:
    """Write the most recently generated HN triplets to the session CSV file.

    The file is written to ``self.dataset_export_file`` with a header row of
    ``Anchor, Positive, Negative`` followed by ``self.last_hn_dataset``.

    Returns:
        The CSV path as a string on success, or ``None`` when there is no
        dataset yet or the file could not be written.
    """
    if not self.last_hn_dataset:
        gr.Warning("No dataset generated yet.")
        return None

    file_path = self.dataset_export_file
    try:
        with open(file_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['Anchor', 'Positive', 'Negative'])
            writer.writerows(self.last_hn_dataset)
    except Exception as e:
        # gr.Error is an exception class: constructing it without `raise`
        # displays nothing. gr.Warning is a plain call, so the user sees the
        # failure and callers still receive None as before.
        gr.Warning(f"Export failed: {e}")
        return None

    gr.Info("Dataset exported.")
    return str(file_path)
|
| 165 |
|
| 166 |
def download_model(self) -> Optional[str]:
    """Zip the session's model output directory and return the archive path.

    The archive is created inside ``self.session_root`` and named
    ``model_finetuned_<unix timestamp>.zip``.

    Returns:
        Path of the created zip archive, or ``None`` when the output
        directory holds nothing to package or zipping failed.
    """
    # The constructor creates output_dir up front, so mere existence does not
    # mean a model was saved there; require at least one entry before zipping.
    if not os.path.isdir(self.output_dir) or not os.listdir(self.output_dir):
        gr.Warning("No model trained yet.")
        return None

    timestamp = int(time.time())
    try:
        base_name = self.session_root / f"model_finetuned_{timestamp}"
        archive_path = shutil.make_archive(
            base_name=str(base_name),
            format='zip',
            root_dir=self.output_dir,
        )
    except Exception as e:
        # gr.Error only shows up when raised; use gr.Warning so the failure is
        # visible while the function keeps returning None to its callers.
        gr.Warning(f"Zip failed: {e}")
        return None

    gr.Info("Model zipped.")
    return archive_path
|
| 185 |
|
| 186 |
## Training Logic ##
|
| 187 |
def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
total_ids, selected_ids = set(self.number_list), set(selected_ids)
|
| 189 |
non_selected_ids = total_ids - selected_ids
|
| 190 |
is_minority = len(selected_ids) < (len(total_ids) / 2)
|
|
|
|
| 194 |
def get_titles(anchor_id, pool_id):
|
| 195 |
return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
|
| 196 |
|
| 197 |
+
if not pool_ids or not anchor_ids:
|
| 198 |
+
return [], "", "" # Should be caught by validation
|
| 199 |
+
|
| 200 |
fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
|
| 201 |
non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
|
| 202 |
|
|
|
|
| 209 |
return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
|
| 210 |
|
| 211 |
def training(self, selected_ids: List[int]) -> str:
    """Fine-tune the session model on the user's selection and report results.

    Builds triplets from ``selected_ids`` with ``_create_hn_dataset``, stores
    them in ``self.last_hn_dataset``, appends ``self.imported_dataset``, runs
    ``train_with_dataset`` with the session output directory, and refreshes
    the vibe checker afterwards.

    Args:
        selected_ids: Indices of the titles the user selected.

    Returns:
        Markdown text with the semantic-search output before and after training.

    Raises:
        gr.Error: If no model is loaded, nothing is selected, every title is
            selected, or the combined dataset is empty.
    """
    if self.model is None:
        raise gr.Error("Model not loaded.")
    if not selected_ids:
        raise gr.Error("Select at least one title.")
    if len(selected_ids) == len(self.number_list):
        raise gr.Error("Cannot select all titles.")

    # Only the triplets are needed here; the two example titles are discarded.
    triplets, _favorite, _non_favorite = self._create_hn_dataset(selected_ids)
    self.last_hn_dataset = triplets
    combined = self.last_hn_dataset + self.imported_dataset

    if not combined:
        raise gr.Error("Dataset is empty.")

    def semantic_search_fn() -> str:
        return get_top_hits(
            model=self.model,
            target_titles=self.target_titles,
            task_name=self.config.TASK_NAME,
            query=self.config.QUERY_ANCHOR,
        )

    # Search snapshot taken before the model is modified.
    report = [f"### Search (Before):\n{semantic_search_fn()}\n\n"]
    print(f"[{self.session_id}] Starting Training...")

    train_with_dataset(
        model=self.model,
        dataset=combined,
        output_dir=self.output_dir,
        task_name=self.config.TASK_NAME,
        search_fn=semantic_search_fn,
    )

    self._update_vibe_checker()
    print(f"[{self.session_id}] Training Complete.")

    # Same search again, now against the trained model.
    report.append(f"### Search (After):\n{semantic_search_fn()}")
    return "".join(report)
|
| 246 |
|
| 247 |
+
## Vibe Check Logic ##
|
| 248 |
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
    """Score ``news_text`` with the session's vibe checker.

    Args:
        news_text: Title or summary to score; needs at least three
            whitespace-separated words.

    Returns:
        A ``(score, status, lamp_update)`` tuple: the raw score formatted to
        four decimals (or ``"N/A"``), a status string, and a Gradio update
        carrying the mood-lamp HTML.
    """
    def lamp(color: str):
        # Gradio update that repaints the mood-lamp block in the given colour.
        return gr.update(value=self._generate_vibe_html(color))

    if not self.vibe_checker:
        return "N/A", "Model Loading...", lamp("gray")

    too_short = (not news_text) or len(news_text.split()) < 3
    if too_short:
        return "N/A", "Text too short", lamp("white")

    try:
        verdict = self.vibe_checker.check(news_text)
        # Status label = text between the first '>' and the next '<' in status_html.
        label = verdict.status_html.split('>')[1].split('<')[0]
        score_text = f"{verdict.raw_score:.4f}"
        return score_text, label, lamp(verdict.color_hsl)
    except Exception as e:
        return "N/A", f"Error: {e}", lamp("gray")
|
|
|
|
| 260 |
|
| 261 |
def _generate_vibe_html(self, color: str) -> str:
|
| 262 |
return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
|
| 263 |
|
| 264 |
+
## Mood Reader Logic ##
|
| 265 |
def fetch_and_display_mood_feed(self) -> str:
|
| 266 |
if not self.vibe_checker:
|
| 267 |
+
return "Model not ready. Please wait or reload."
|
| 268 |
|
| 269 |
feed, status = read_hacker_news_rss(self.config)
|
| 270 |
if not feed or not feed.entries:
|
| 271 |
+
return f"**Feed Error:** {status}"
|
| 272 |
|
| 273 |
scored_entries = []
|
| 274 |
for entry in feed.entries:
|
|
|
|
| 286 |
|
| 287 |
scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
|
| 288 |
|
| 289 |
+
md = (f"## Hacker News Mood (Session: {self.session_id[:6]})\n"
|
| 290 |
+
f"**Updated:** {datetime.now().strftime('%H:%M:%S')}\n\n"
|
| 291 |
"| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
|
| 292 |
|
| 293 |
for item in scored_entries:
|
|
|
|
| 297 |
f"| [Comments]({item['comments']}) "
|
| 298 |
f"| {item['published']} |\n")
|
| 299 |
return md
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
+
# --- Session Wrappers ---
|
| 303 |
+
# These functions act as bridges between Gradio inputs and the session object.
|
| 304 |
|
| 305 |
+
def create_session():
    """Build and return a new ``HackerNewsFineTuner`` configured with ``AppConfig``."""
    session = HackerNewsFineTuner(AppConfig)
    return session
|
| 308 |
+
|
| 309 |
+
def refresh_wrapper(app):
    """Gradio bridge: call ``app.refresh_data_and_model()`` and return its result."""
    updates = app.refresh_data_and_model()
    return updates
|
| 311 |
+
|
| 312 |
+
def import_wrapper(app, file):
    """Gradio bridge: pass ``file`` to ``app.import_additional_dataset`` and return its result."""
    status = app.import_additional_dataset(file)
    return status
|
| 314 |
+
|
| 315 |
+
def export_wrapper(app):
    """Gradio bridge: call ``app.export_dataset()`` and return its result."""
    exported_path = app.export_dataset()
    return exported_path
|
| 317 |
+
|
| 318 |
+
def download_model_wrapper(app):
    """Gradio bridge: call ``app.download_model()`` and return its result."""
    archive = app.download_model()
    return archive
|
| 320 |
+
|
| 321 |
+
def training_wrapper(app, selected_ids):
    """Gradio bridge: pass ``selected_ids`` to ``app.training`` and return its result."""
    report = app.training(selected_ids)
    return report
|
| 323 |
|
| 324 |
+
def vibe_check_wrapper(app, text):
    """Gradio bridge: pass ``text`` to ``app.get_vibe_check`` and return its result."""
    outcome = app.get_vibe_check(text)
    return outcome
|
| 326 |
+
|
| 327 |
+
def mood_feed_wrapper(app):
    """Gradio bridge: call ``app.fetch_and_display_mood_feed()`` and return its result."""
    markdown = app.fetch_and_display_mood_feed()
    return markdown
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# --- Interface Setup ---
|
| 332 |
+
|
| 333 |
+
def build_interface() -> gr.Blocks:
    """Assemble the three-tab Gradio UI and wire every event to a session wrapper.

    Each handler receives the ``gr.State`` value as its first input, so the
    module-level wrapper functions operate on the caller's own session object.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
        # Initialised from the create_session factory; holds the per-user
        # HackerNewsFineTuner instance.
        session_state = gr.State(create_session)

        gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader (Multi-User)")
        gr.Markdown("Each browser tab creates a unique session with isolated training data and models.")

        # ---------------- Tab 1: fine-tuning ----------------
        with gr.Tab("🚀 Fine-Tuning & Evaluation"):
            with gr.Column():
                gr.Markdown("## Fine-Tuning & Semantic Search")
                with gr.Row():
                    # Starts empty; the load event below fills in the choices.
                    title_picker = gr.CheckboxGroup(
                        choices=[],
                        type="index",
                        label="Hacker News Top Stories",
                        show_select_all=True,
                    )
                    results_box = gr.Textbox(
                        lines=14,
                        label="Training and Search Results",
                        value="Loading data...",
                    )

                with gr.Row():
                    reload_btn = gr.Button("Clear & Reload")
                    train_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")

                gr.Markdown("--- \n ## Dataset & Model Management")
                csv_upload = gr.File(
                    label="Upload Additional Dataset (.csv)",
                    file_types=[".csv"],
                    height=50,
                )

                with gr.Row():
                    export_btn = gr.Button("💾 Export Dataset")
                    model_btn = gr.Button("⬇️ Download Model")

                status_md = gr.Markdown("Ready.")

                with gr.Row():
                    # Hidden until the corresponding export/zip yields a file.
                    csv_file = gr.File(label="Dataset CSV", height=50, visible=False, interactive=False)
                    zip_file = gr.File(label="Model ZIP", height=50, visible=False, interactive=False)

                # Populate titles and status text when the page loads.
                demo.load(
                    fn=refresh_wrapper,
                    inputs=[session_state],
                    outputs=[title_picker, results_box],
                )

                reload_btn.click(
                    fn=refresh_wrapper,
                    inputs=[session_state],
                    outputs=[title_picker, results_box],
                )

                train_btn.click(
                    fn=training_wrapper,
                    inputs=[session_state, title_picker],
                    outputs=[results_box],
                )

                csv_upload.change(
                    fn=import_wrapper,
                    inputs=[session_state, csv_upload],
                    outputs=[status_md],
                )

                # Export, then reveal the file component only if a path came back.
                export_btn.click(
                    fn=export_wrapper,
                    inputs=[session_state],
                    outputs=[csv_file],
                ).then(
                    lambda p: gr.update(visible=True) if p else gr.update(),
                    inputs=[csv_file],
                    outputs=[csv_file],
                )

                # Same reveal-on-success pattern for the model archive.
                model_btn.click(
                    fn=download_model_wrapper,
                    inputs=[session_state],
                    outputs=[zip_file],
                ).then(
                    lambda p: gr.update(visible=True) if p else gr.update(),
                    inputs=[zip_file],
                    outputs=[zip_file],
                )

        # ---------------- Tab 2: mood reader ----------------
        with gr.Tab("📰 Hacker News Mood Reader"):
            with gr.Column():
                gr.Markdown("## Live Hacker News Feed Vibe")
                feed_md = gr.Markdown(value="Click 'Refresh Feed'...", label="Latest Stories")
                feed_btn = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
                feed_btn.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_md)

        # ---------------- Tab 3: similarity check ----------------
        with gr.Tab("💡 Similarity Check"):
            with gr.Column():
                gr.Markdown("## News Vibe Check Mood Lamp")
                text_in = gr.Textbox(label="Enter News Title or Summary", lines=3)
                check_btn = gr.Button("Check Vibe", variant="primary")
                with gr.Row():
                    lamp_html = gr.HTML(
                        value='<div style="background-color: gray; height: 100px;"></div>',
                        label="Mood Lamp",
                    )
                    with gr.Column():
                        score_box = gr.Textbox(label="Score", value="N/A", interactive=False)
                        status_box = gr.Textbox(label="Status", value="...", interactive=False)

                check_btn.click(
                    fn=vibe_check_wrapper,
                    inputs=[session_state, text_in],
                    outputs=[score_box, status_box, lamp_html],
                )

    return demo
|
| 432 |
+
|
| 433 |
+
if __name__ == "__main__":
    # Script entry point: build the UI, announce startup, then serve it with
    # Gradio's default launch settings.
    app_demo = build_interface()
    print("Starting Multi-User Gradio App...")
    app_demo.launch()
|