Upload folder using huggingface_hub
Browse files- example_training_dataset.csv +4 -0
- src/config.py +0 -3
- src/session_manager.py +70 -55
- src/ui.py +282 -108
example_training_dataset.csv
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Anchor,Positive,Negative
|
| 2 |
+
MY_FAVORITE_NEWS,Denial of service and source code exposure in React Server Components,An SVG is all you need
|
| 3 |
+
MY_FAVORITE_NEWS,The highest quality codebase,Litestream VFS
|
| 4 |
+
|
src/config.py
CHANGED
|
@@ -45,9 +45,6 @@ class AppConfig:
|
|
| 45 |
# Anchor text used for contrastive learning dataset generation
|
| 46 |
QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"
|
| 47 |
|
| 48 |
-
# Number of titles shown for user selection in the Gradio interface
|
| 49 |
-
TOP_TITLES_COUNT: Final[int] = 10
|
| 50 |
-
|
| 51 |
# Default export path for the dataset CSV
|
| 52 |
DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")
|
| 53 |
|
|
|
|
| 45 |
# Anchor text used for contrastive learning dataset generation
|
| 46 |
QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
# Default export path for the dataset CSV
|
| 49 |
DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")
|
| 50 |
|
src/session_manager.py
CHANGED
|
@@ -45,8 +45,6 @@ class HackerNewsFineTuner:
|
|
| 45 |
self.model: Optional[SentenceTransformer] = None
|
| 46 |
self.vibe_checker: Optional[VibeChecker] = None
|
| 47 |
self.titles: List[str] = []
|
| 48 |
-
self.target_titles: List[str] = []
|
| 49 |
-
self.number_list: List[int] = []
|
| 50 |
self.last_hn_dataset: List[List[str]] = []
|
| 51 |
self.imported_dataset: List[List[str]] = []
|
| 52 |
|
|
@@ -66,9 +64,12 @@ class HackerNewsFineTuner:
|
|
| 66 |
|
| 67 |
## Data and Model Management ##
|
| 68 |
|
| 69 |
-
def refresh_data_and_model(self) -> Tuple[
|
| 70 |
"""
|
| 71 |
Reloads model and fetches data.
|
|
|
|
|
|
|
|
|
|
| 72 |
"""
|
| 73 |
print(f"[{self.session_id}] Reloading model and data...")
|
| 74 |
|
|
@@ -84,34 +85,23 @@ class HackerNewsFineTuner:
|
|
| 84 |
print(error_msg)
|
| 85 |
self.model = None
|
| 86 |
self._update_vibe_checker()
|
| 87 |
-
return
|
| 88 |
-
gr.update(choices=[], label="Model Load Failed"),
|
| 89 |
-
gr.update(value=error_msg)
|
| 90 |
-
)
|
| 91 |
|
| 92 |
# 2. Fetch fresh news data
|
| 93 |
news_feed, status_msg = read_hacker_news_rss(self.config)
|
| 94 |
-
titles_out
|
| 95 |
status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
|
| 96 |
|
| 97 |
if news_feed is not None and news_feed.entries:
|
| 98 |
-
titles_out = [item.title for item in news_feed.entries
|
| 99 |
-
target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
|
| 100 |
else:
|
| 101 |
-
titles_out = ["Error fetching news."
|
| 102 |
gr.Warning(f"Data reload failed. {status_msg}")
|
| 103 |
|
| 104 |
self.titles = titles_out
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
return (
|
| 109 |
-
gr.update(
|
| 110 |
-
choices=self.titles,
|
| 111 |
-
label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
|
| 112 |
-
),
|
| 113 |
-
gr.update(value=status_value)
|
| 114 |
-
)
|
| 115 |
|
| 116 |
# --- Import Dataset/Export ---
|
| 117 |
def import_additional_dataset(self, file_path: str) -> str:
|
|
@@ -123,6 +113,7 @@ class HackerNewsFineTuner:
|
|
| 123 |
reader = csv.reader(f)
|
| 124 |
try:
|
| 125 |
header = next(reader)
|
|
|
|
| 126 |
if not (header and header[0].lower().strip() == 'anchor'):
|
| 127 |
f.seek(0)
|
| 128 |
except StopIteration:
|
|
@@ -188,54 +179,75 @@ class HackerNewsFineTuner:
|
|
| 188 |
|
| 189 |
|
| 190 |
## Training Logic ##
|
| 191 |
-
def _create_hn_dataset(self,
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
def get_titles(anchor_id, pool_id):
|
| 199 |
-
return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
|
|
|
| 203 |
|
| 204 |
-
|
| 205 |
-
non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
return
|
| 214 |
|
| 215 |
-
def training(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
if self.model is None:
|
| 217 |
raise gr.Error("Model not loaded.")
|
| 218 |
-
if not selected_ids:
|
| 219 |
-
raise gr.Error("Select at least one title.")
|
| 220 |
-
if len(selected_ids) == len(self.number_list):
|
| 221 |
-
raise gr.Error("Cannot select all titles.")
|
| 222 |
-
|
| 223 |
-
hn_dataset, _, _ = self._create_hn_dataset(selected_ids)
|
| 224 |
-
self.last_hn_dataset = hn_dataset
|
| 225 |
-
final_dataset = self.last_hn_dataset + self.imported_dataset
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
def semantic_search_fn() -> str:
|
| 231 |
-
return get_top_hits(model=self.model, target_titles=self.
|
| 232 |
|
| 233 |
result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
|
| 234 |
-
print(f"[{self.session_id}] Starting Training...")
|
| 235 |
|
| 236 |
train_with_dataset(
|
| 237 |
model=self.model,
|
| 238 |
-
dataset=
|
| 239 |
output_dir=self.output_dir,
|
| 240 |
task_name=self.config.TASK_NAME,
|
| 241 |
search_fn=semantic_search_fn
|
|
@@ -247,6 +259,9 @@ class HackerNewsFineTuner:
|
|
| 247 |
result += "### Search (After):\n" + f"{semantic_search_fn()}"
|
| 248 |
return result
|
| 249 |
|
|
|
|
|
|
|
|
|
|
| 250 |
## Vibe Check Logic ##
|
| 251 |
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
|
| 252 |
info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
|
|
@@ -303,4 +318,4 @@ class HackerNewsFineTuner:
|
|
| 303 |
f"| [{item['title']}]({item['link']}) "
|
| 304 |
f"| [Comments]({item['comments']}) "
|
| 305 |
f"| {item['published']} |\n")
|
| 306 |
-
return md
|
|
|
|
| 45 |
self.model: Optional[SentenceTransformer] = None
|
| 46 |
self.vibe_checker: Optional[VibeChecker] = None
|
| 47 |
self.titles: List[str] = []
|
|
|
|
|
|
|
| 48 |
self.last_hn_dataset: List[List[str]] = []
|
| 49 |
self.imported_dataset: List[List[str]] = []
|
| 50 |
|
|
|
|
| 64 |
|
| 65 |
## Data and Model Management ##
|
| 66 |
|
| 67 |
+
def refresh_data_and_model(self) -> Tuple[List[str], str]:
|
| 68 |
"""
|
| 69 |
Reloads model and fetches data.
|
| 70 |
+
Returns:
|
| 71 |
+
- List of titles (for the UI)
|
| 72 |
+
- Status message string
|
| 73 |
"""
|
| 74 |
print(f"[{self.session_id}] Reloading model and data...")
|
| 75 |
|
|
|
|
| 85 |
print(error_msg)
|
| 86 |
self.model = None
|
| 87 |
self._update_vibe_checker()
|
| 88 |
+
return [], error_msg
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# 2. Fetch fresh news data
|
| 91 |
news_feed, status_msg = read_hacker_news_rss(self.config)
|
| 92 |
+
titles_out = []
|
| 93 |
status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
|
| 94 |
|
| 95 |
if news_feed is not None and news_feed.entries:
|
| 96 |
+
titles_out = [item.title for item in news_feed.entries]
|
|
|
|
| 97 |
else:
|
| 98 |
+
titles_out = ["Error fetching news."]
|
| 99 |
gr.Warning(f"Data reload failed. {status_msg}")
|
| 100 |
|
| 101 |
self.titles = titles_out
|
| 102 |
+
|
| 103 |
+
# Return raw list of titles + status text
|
| 104 |
+
return self.titles, status_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# --- Import Dataset/Export ---
|
| 107 |
def import_additional_dataset(self, file_path: str) -> str:
|
|
|
|
| 113 |
reader = csv.reader(f)
|
| 114 |
try:
|
| 115 |
header = next(reader)
|
| 116 |
+
# Simple heuristic to detect if header exists
|
| 117 |
if not (header and header[0].lower().strip() == 'anchor'):
|
| 118 |
f.seek(0)
|
| 119 |
except StopIteration:
|
|
|
|
| 179 |
|
| 180 |
|
| 181 |
## Training Logic ##
|
| 182 |
+
def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]:
|
| 183 |
+
"""
|
| 184 |
+
Creates triplets (Anchor, Positive, Negative) from the selected indices.
|
| 185 |
+
Uses cycling to balance the dataset if the number of positives != negatives.
|
| 186 |
+
"""
|
| 187 |
+
if not pos_ids or not neg_ids:
|
| 188 |
+
return []
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
# Convert indices to actual title strings
|
| 191 |
+
pos_titles = [self.titles[i] for i in pos_ids]
|
| 192 |
+
neg_titles = [self.titles[i] for i in neg_ids]
|
| 193 |
|
| 194 |
+
dataset = []
|
|
|
|
| 195 |
|
| 196 |
+
# We need to pair every Positive with a Negative.
|
| 197 |
+
# Strategy: Iterate over the longer list and cycle through the shorter list
|
| 198 |
+
# to ensure every selected item is used at least once and the dataset is balanced.
|
| 199 |
+
|
| 200 |
+
if len(pos_titles) >= len(neg_titles):
|
| 201 |
+
# More positives than negatives: Iterate positives, reuse negatives
|
| 202 |
+
neg_cycle = cycle(neg_titles)
|
| 203 |
+
for p_title in pos_titles:
|
| 204 |
+
dataset.append([self.config.QUERY_ANCHOR, p_title, next(neg_cycle)])
|
| 205 |
+
else:
|
| 206 |
+
# More negatives than positives: Iterate negatives, reuse positives
|
| 207 |
+
pos_cycle = cycle(pos_titles)
|
| 208 |
+
for n_title in neg_titles:
|
| 209 |
+
dataset.append([self.config.QUERY_ANCHOR, next(pos_cycle), n_title])
|
| 210 |
|
| 211 |
+
return dataset
|
| 212 |
|
| 213 |
+
def training(self, pos_ids: List[int], neg_ids: List[int]) -> str:
|
| 214 |
+
"""
|
| 215 |
+
Main training entry point.
|
| 216 |
+
Args:
|
| 217 |
+
pos_ids: Indices of stories marked as "Favorite"
|
| 218 |
+
neg_ids: Indices of stories marked as "Dislike"
|
| 219 |
+
"""
|
| 220 |
if self.model is None:
|
| 221 |
raise gr.Error("Model not loaded.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
+
# Validation
|
| 224 |
+
if not pos_ids:
|
| 225 |
+
raise gr.Error("Please select at least one 'Favorite' story.")
|
| 226 |
+
if not neg_ids:
|
| 227 |
+
raise gr.Error("Please select at least one 'Dislike' story.")
|
| 228 |
+
|
| 229 |
+
# Generate Dataset
|
| 230 |
+
hn_dataset = self._create_hn_dataset(pos_ids, neg_ids)
|
| 231 |
+
|
| 232 |
+
# Merge with imported dataset if it exists
|
| 233 |
+
if self.imported_dataset:
|
| 234 |
+
# If we have both, combine them
|
| 235 |
+
self.last_hn_dataset = hn_dataset + self.imported_dataset
|
| 236 |
+
else:
|
| 237 |
+
self.last_hn_dataset = hn_dataset
|
| 238 |
+
|
| 239 |
+
if not self.last_hn_dataset:
|
| 240 |
+
raise gr.Error("Dataset generation failed (Empty dataset).")
|
| 241 |
|
| 242 |
def semantic_search_fn() -> str:
|
| 243 |
+
return get_top_hits(model=self.model, target_titles=self.titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
|
| 244 |
|
| 245 |
result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
|
| 246 |
+
print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...")
|
| 247 |
|
| 248 |
train_with_dataset(
|
| 249 |
model=self.model,
|
| 250 |
+
dataset=self.last_hn_dataset,
|
| 251 |
output_dir=self.output_dir,
|
| 252 |
task_name=self.config.TASK_NAME,
|
| 253 |
search_fn=semantic_search_fn
|
|
|
|
| 259 |
result += "### Search (After):\n" + f"{semantic_search_fn()}"
|
| 260 |
return result
|
| 261 |
|
| 262 |
+
def is_model_tuned(self) -> bool:
    """
    Report whether this session holds a non-empty training dataset.

    ``training`` stores the dataset it trains on in ``self.last_hn_dataset``,
    and ``get_vibe_check`` uses the same truthiness check for its
    "Fine-tuned" marker, so a non-empty value is the "tuned" signal here.

    Returns:
        True when ``self.last_hn_dataset`` is non-empty, otherwise False.
    """
    return bool(self.last_hn_dataset)
|
| 264 |
+
|
| 265 |
## Vibe Check Logic ##
|
| 266 |
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
|
| 267 |
info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
|
|
|
|
| 318 |
f"| [{item['title']}]({item['link']}) "
|
| 319 |
f"| [Comments]({item['comments']}) "
|
| 320 |
f"| {item['published']} |\n")
|
| 321 |
+
return md
|
src/ui.py
CHANGED
|
@@ -1,26 +1,64 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from typing import Optional
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
from .config import AppConfig
|
| 6 |
from .session_manager import HackerNewsFineTuner
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# --- Session Wrappers ---
|
| 9 |
|
| 10 |
def refresh_wrapper(app):
|
| 11 |
"""
|
| 12 |
Initializes the session if it's not already created, then runs the refresh.
|
| 13 |
-
Returns
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
if app is None or callable(app) or isinstance(app, type):
|
| 16 |
print("Initializing new HackerNewsFineTuner session...")
|
| 17 |
app = HackerNewsFineTuner(AppConfig)
|
| 18 |
|
| 19 |
# Run the refresh logic
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
def import_wrapper(app, file):
|
| 26 |
return app.import_additional_dataset(file)
|
|
@@ -32,19 +70,31 @@ def download_model_wrapper(app):
|
|
| 32 |
return app.download_model()
|
| 33 |
|
| 34 |
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
|
| 35 |
-
"""
|
| 36 |
-
Wrapper for pushing the model to the Hugging Face Hub.
|
| 37 |
-
Gradio automatically injects 'oauth_token' if the user is logged in via LoginButton.
|
| 38 |
-
"""
|
| 39 |
if oauth_token is None:
|
| 40 |
return "⚠️ You must be logged in to push to the Hub. Please sign in above."
|
| 41 |
-
|
| 42 |
-
# Extract the token string from the OAuthToken object
|
| 43 |
token_str = oauth_token.token
|
| 44 |
return app.upload_model(repo_name, token_str)
|
| 45 |
|
| 46 |
-
def training_wrapper(app,
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def vibe_check_wrapper(app, text):
|
| 50 |
return app.get_vibe_check(text)
|
|
@@ -57,124 +107,248 @@ def mood_feed_wrapper(app):
|
|
| 57 |
|
| 58 |
def build_interface() -> gr.Blocks:
|
| 59 |
with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
|
| 60 |
-
# Initialize state as None. It will be populated by refresh_wrapper on load.
|
| 61 |
session_state = gr.State()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
with gr.Column():
|
| 64 |
gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
|
| 65 |
gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
|
| 66 |
-
gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")
|
| 67 |
|
| 68 |
with gr.Tab("🚀 Fine-Tuning & Evaluation"):
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
with gr.Row():
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
-
gr.
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
with gr.Row():
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
download_status = gr.Markdown("Ready.")
|
| 88 |
|
| 89 |
-
|
| 90 |
-
dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
|
| 91 |
-
model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
with gr.Row():
|
| 95 |
-
repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g., my-news-vibe
|
| 96 |
-
push_to_hub_btn = gr.Button("
|
|
|
|
|
|
|
| 97 |
|
| 98 |
push_status = gr.Markdown("")
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
clear_reload_btn,
|
| 111 |
-
run_training_btn,
|
| 112 |
-
download_dataset_btn,
|
| 113 |
-
download_model_btn,
|
| 114 |
-
push_to_hub_btn
|
| 115 |
-
]
|
| 116 |
-
|
| 117 |
-
# 2. Buttons
|
| 118 |
-
clear_reload_btn.click(
|
| 119 |
-
fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
|
| 120 |
-
outputs=buttons_to_lock
|
| 121 |
-
).then(
|
| 122 |
-
fn=refresh_wrapper,
|
| 123 |
-
inputs=[session_state],
|
| 124 |
-
outputs=[session_state, favorite_list, output]
|
| 125 |
-
).then(
|
| 126 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 127 |
-
outputs=buttons_to_lock
|
| 128 |
-
)
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
).then(
|
| 134 |
-
fn=training_wrapper,
|
| 135 |
-
inputs=[session_state, favorite_list],
|
| 136 |
-
outputs=[output]
|
| 137 |
-
).then(
|
| 138 |
-
fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
|
| 139 |
-
outputs=buttons_to_lock
|
| 140 |
-
)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
)
|
| 178 |
|
| 179 |
with gr.Tab("📰 Hacker News Similarity Check"):
|
| 180 |
with gr.Column():
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from typing import Optional, Dict, List
|
| 3 |
from datetime import datetime
|
| 4 |
|
| 5 |
from .config import AppConfig
|
| 6 |
from .session_manager import HackerNewsFineTuner
|
| 7 |
|
| 8 |
+
# --- Constants for Labels ---
|
| 9 |
+
LABEL_FAV = "👍"
|
| 10 |
+
LABEL_NEU = "😐"
|
| 11 |
+
LABEL_DIS = "👎"
|
| 12 |
+
|
| 13 |
# --- Session Wrappers ---
|
| 14 |
|
| 15 |
def refresh_wrapper(app):
    """
    Ensure a session object exists, then reload its model and news data.

    Args:
        app: The current session object, or a placeholder (``None`` or any
            callable, including a class) when no session exists yet.

    Returns:
        A 4-tuple ``(app, titles, labels, log_text)``:
        1. the (possibly newly created) session object,
        2. the list of story titles returned by the refresh,
        3. a fresh empty dict, which resets the user's ratings,
        4. the status text returned by the refresh.
    """
    # Classes are callable as well, so `callable(app)` already covers the
    # "a class was passed instead of an instance" case.
    if app is None or callable(app):
        print("Initializing new HackerNewsFineTuner session...")
        app = HackerNewsFineTuner(AppConfig)

    # titles is a plain list of strings: ["Title 1", "Title 2", ...]
    titles, log_text = app.refresh_data_and_model()

    return app, titles, {}, log_text
|
| 36 |
+
|
| 37 |
+
def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
    """
    Startup handler: refresh the session and derive the Hub-related UI state.

    Args:
        app: Session object (or placeholder) forwarded to ``refresh_wrapper``.
        profile: OAuth profile of the signed-in user, or None when signed out.

    Returns:
        A 7-tuple: the four values from ``refresh_wrapper`` (session, stories,
        labels, log text), the same interactivity update twice (one per
        Hub-related component wired to this handler), and the user name
        (None when signed out).
    """
    # Step 1: same refresh path as the manual reset button.
    app, stories, labels, text_update = refresh_wrapper(app)

    # Step 2: Hub controls are usable only for a signed-in user.
    if profile is None:
        signed_in, username = False, None
    else:
        signed_in, username = True, profile.username
    hub_toggle = gr.update(interactive=signed_in)

    # Order must match the outputs list of the load event.
    return app, stories, labels, text_update, hub_toggle, hub_toggle, username
|
| 54 |
+
|
| 55 |
+
def update_repo_preview(username, repo_name):
    """
    Build the Markdown line previewing the ``username/repo_name`` Hub target.

    Args:
        username: Hub user name, or a falsy value when nobody is signed in.
        repo_name: Raw text of the repository-name box; may be None or blank.

    Returns:
        A sign-in hint when ``username`` is falsy; otherwise the Markdown
        preview, with a missing or whitespace-only name shown as ``...``.
    """
    if not username:
        return "⚠️ Sign in to see the target repository path."

    # Strip first, then fall back: whitespace-only input would otherwise
    # render as "username/" with nothing after the slash.
    clean_repo = (repo_name or "").strip() or "..."
    return f"Target Repository: **`{username}/{clean_repo}`**"
|
| 62 |
|
| 63 |
def import_wrapper(app, file):
|
| 64 |
return app.import_additional_dataset(file)
|
|
|
|
| 70 |
return app.download_model()
|
| 71 |
|
| 72 |
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
    """
    Upload the session's model to the Hub on behalf of the signed-in user.

    Args:
        app: Session object exposing ``upload_model(repo_name, token)``.
        repo_name: Target repository name as typed by the user.
        oauth_token: OAuth token object, or None when the user is signed out.
            It is not listed in the click handler's inputs.
            NOTE(review): presumably Gradio injects it from the type
            annotation - confirm against the Gradio OAuth docs.

    Returns:
        The result of ``app.upload_model``, or a warning string when no
        token is available.
    """
    if oauth_token is not None:
        # Only the raw token string is handed to the upload call.
        return app.upload_model(repo_name, oauth_token.token)
    return "⚠️ You must be logged in to push to the Hub. Please sign in above."
|
| 77 |
|
| 78 |
+
def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
    """
    Turn the per-story ratings into index lists and start training.

    Args:
        app: Session object exposing ``training(pos_ids, neg_ids)``.
        stories: Titles currently shown; only their positions are used.
        labels: Mapping of story index to LABEL_FAV, LABEL_NEU or LABEL_DIS.
            Indices without an entry count as neutral, and entries whose
            index is outside ``stories`` are ignored. ``None`` is treated
            as "nothing rated".

    Returns:
        Whatever ``app.training`` returns.
    """
    # Guard against an unset state value so the lookup below cannot fail.
    ratings = labels or {}
    indices = range(len(stories or []))

    # Walk story positions (not the dict) so results stay in display order
    # and stale keys beyond the current list are dropped.
    pos_ids = [i for i in indices if ratings.get(i) == LABEL_FAV]
    neg_ids = [i for i in indices if ratings.get(i) == LABEL_DIS]

    return app.training(pos_ids, neg_ids)
|
| 98 |
|
| 99 |
def vibe_check_wrapper(app, text):
|
| 100 |
return app.get_vibe_check(text)
|
|
|
|
| 107 |
|
| 108 |
def build_interface() -> gr.Blocks:
|
| 109 |
with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
|
|
|
|
| 110 |
session_state = gr.State()
|
| 111 |
+
username_state = gr.State()
|
| 112 |
+
|
| 113 |
+
# State variables for the Feed List and User Choices
|
| 114 |
+
stories_state = gr.State([])
|
| 115 |
+
labels_state = gr.State({})
|
| 116 |
+
reset_counter = gr.State(0)
|
| 117 |
|
| 118 |
with gr.Column():
|
| 119 |
gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
|
| 120 |
gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
|
|
|
|
| 121 |
|
| 122 |
with gr.Tab("🚀 Fine-Tuning & Evaluation"):
|
| 123 |
+
|
| 124 |
+
# --- Model Indicator ---
|
| 125 |
+
gr.Dropdown(
|
| 126 |
+
choices=[f"{AppConfig.MODEL_NAME}"],
|
| 127 |
+
value=f"{AppConfig.MODEL_NAME}",
|
| 128 |
+
label="Base Model for Fine-tuning",
|
| 129 |
+
interactive=False
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# --- Step 0: Login ---
|
| 133 |
+
with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
|
| 134 |
+
gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
|
| 135 |
with gr.Row():
|
| 136 |
+
gr.LoginButton(value="Sign in with Hugging Face")
|
| 137 |
+
with gr.Column(scale=3):
|
| 138 |
+
gr.Markdown("")
|
| 139 |
+
|
| 140 |
+
# --- Step 1: Data Selection ---
|
| 141 |
+
with gr.Accordion("1️⃣ Step 1: Select Data Source", open=True):
|
| 142 |
+
gr.Markdown("Select titles from the live Hacker News feed **OR** upload your own CSV dataset to prepare your training data.")
|
| 143 |
|
| 144 |
+
with gr.Column():
|
| 145 |
+
# Option A: Live Feed (Radio List)
|
| 146 |
+
with gr.Accordion("Option A: Live Hacker News Feed", open=True):
|
| 147 |
+
gr.Markdown("Rate the stories below to define your vibe.\n\n**⚠️ Note: You must select at least one Favorite and one Dislike to run training.**")
|
| 148 |
+
|
| 149 |
+
with gr.Row():
|
| 150 |
+
reset_all_btn = gr.Button("Reset Selection ↺", variant="secondary", scale=1)
|
| 151 |
+
with gr.Column(scale=3):
|
| 152 |
+
gr.Markdown("")
|
| 153 |
+
|
| 154 |
+
# Dynamic rendering of the story list
|
| 155 |
+
@gr.render(inputs=[stories_state, reset_counter])
|
| 156 |
+
def render_story_list(stories, _counter):
|
| 157 |
+
if not stories:
|
| 158 |
+
gr.Markdown("*No stories loaded. Click 'Reset Model & Fine-tuning state' to fetch data.*")
|
| 159 |
+
return
|
| 160 |
+
|
| 161 |
+
for i, title in enumerate(stories):
|
| 162 |
+
with gr.Row(variant="compact", elem_id=f"story_row_{i}"):
|
| 163 |
+
# Title
|
| 164 |
+
with gr.Column(scale=3):
|
| 165 |
+
gr.Markdown(f"**{i+1}.** {title}")
|
| 166 |
+
|
| 167 |
+
# Radio Selection
|
| 168 |
+
radio = gr.Radio(
|
| 169 |
+
choices=[LABEL_FAV, LABEL_NEU, LABEL_DIS],
|
| 170 |
+
value=LABEL_NEU,
|
| 171 |
+
show_label=False,
|
| 172 |
+
container=False,
|
| 173 |
+
min_width=80,
|
| 174 |
+
scale=1,
|
| 175 |
+
interactive=True
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Update logic
|
| 179 |
+
def update_label(new_val, current_labels, idx=i):
|
| 180 |
+
current_labels[idx] = new_val
|
| 181 |
+
return current_labels
|
| 182 |
+
|
| 183 |
+
radio.change(
|
| 184 |
+
fn=update_label,
|
| 185 |
+
inputs=[radio, labels_state],
|
| 186 |
+
outputs=[labels_state]
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# Option B: Upload
|
| 190 |
+
with gr.Accordion("Option B: Upload Custom Dataset", open=False):
|
| 191 |
+
gr.Markdown("Upload a CSV file with columns (no header required, or header ignored if present): `Anchor`, `Positive`, `Negative`.")
|
| 192 |
+
# The link must use the real file name shipped at the repo root (example_training_dataset.csv).
gr.Markdown("See also: [example_training_dataset.csv](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/example_training_dataset.csv)<br>Example:<br>`MY_FAVORITE_NEWS,Good Title,Bad Title`")
|
| 193 |
+
import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=100)
|
| 194 |
+
|
| 195 |
+
# --- Step 2: Training ---
|
| 196 |
+
with gr.Accordion("2️⃣ Step 2: Run Tuning", open=True):
|
| 197 |
+
gr.Markdown("Fine-tune the model using the data selected or uploaded above.")
|
| 198 |
|
| 199 |
with gr.Row():
|
| 200 |
+
run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
|
| 201 |
+
clear_reload_btn = gr.Button("Reset Model & Fine-tuning state", scale=1)
|
|
|
|
|
|
|
| 202 |
|
| 203 |
+
output = gr.Textbox(lines=10, label="Training Logs & Search Results", value="Waiting to start...", autoscroll=True)
|
|
|
|
|
|
|
| 204 |
|
| 205 |
+
# --- Step 3: Push to Hub ---
|
| 206 |
+
with gr.Accordion("3️⃣ Step 3: Save to Hugging Face Hub (Optional)", open=False):
|
| 207 |
+
gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
|
| 208 |
+
|
| 209 |
with gr.Row():
|
| 210 |
+
repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
|
| 211 |
+
push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
|
| 212 |
+
|
| 213 |
+
repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
|
| 214 |
|
| 215 |
push_status = gr.Markdown("")
|
| 216 |
|
| 217 |
+
# --- Step 4: Downloads ---
|
| 218 |
+
with gr.Accordion("4️⃣ Step 4: Download Artifacts", open=False):
|
| 219 |
+
gr.Markdown("Export your combined dataset or download the fine-tuned model locally.")
|
| 220 |
+
|
| 221 |
+
with gr.Row():
|
| 222 |
+
download_dataset_btn = gr.Button("💾 Export Dataset", interactive=False)
|
| 223 |
+
download_model_btn = gr.Button("⬇️ Download Model ZIP", interactive=False)
|
|
|
|
| 224 |
|
| 225 |
+
download_status = gr.Markdown("Ready.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
+
with gr.Row():
|
| 228 |
+
dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
|
| 229 |
+
model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
+
# --- Interaction Logic ---
|
| 232 |
+
|
| 233 |
+
action_buttons = [
|
| 234 |
+
clear_reload_btn,
|
| 235 |
+
run_training_btn,
|
| 236 |
+
download_dataset_btn,
|
| 237 |
+
download_model_btn
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
def set_interactivity(interactive: bool):
|
| 241 |
+
"""Helper to lock/unlock all main action buttons."""
|
| 242 |
+
return [gr.update(interactive=interactive) for _ in action_buttons]
|
| 243 |
+
|
| 244 |
+
# 1. App Startup
|
| 245 |
+
# ----------------
|
| 246 |
+
demo.load(
|
| 247 |
+
fn=lambda: set_interactivity(False), outputs=action_buttons
|
| 248 |
+
).then(
|
| 249 |
+
fn=on_app_load,
|
| 250 |
+
inputs=[session_state],
|
| 251 |
+
outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
|
| 252 |
+
).then(
|
| 253 |
+
fn=update_repo_preview,
|
| 254 |
+
inputs=[username_state, repo_name_input],
|
| 255 |
+
outputs=[repo_id_preview]
|
| 256 |
+
).then(
|
| 257 |
+
fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# 2. Reset / Refresh / Clear Selections
|
| 261 |
+
# ----------------
|
| 262 |
+
clear_reload_btn.click(
|
| 263 |
+
fn=lambda: set_interactivity(False), outputs=action_buttons
|
| 264 |
+
).then(
|
| 265 |
+
fn=refresh_wrapper,
|
| 266 |
+
inputs=[session_state],
|
| 267 |
+
outputs=[session_state, stories_state, labels_state, output]
|
| 268 |
+
).then(
|
| 269 |
+
fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Reset Selection Button Logic
|
| 273 |
+
def reset_all_selections(counter):
|
| 274 |
+
# Returns: (incremented counter, empty dict for labels)
|
| 275 |
+
return counter + 1, {}
|
| 276 |
|
| 277 |
+
reset_all_btn.click(
|
| 278 |
+
fn=reset_all_selections,
|
| 279 |
+
inputs=[reset_counter],
|
| 280 |
+
outputs=[reset_counter, labels_state]
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# 3. Import Data
|
| 284 |
+
# ----------------
|
| 285 |
+
import_file.change(
|
| 286 |
+
fn=import_wrapper,
|
| 287 |
+
inputs=[session_state, import_file],
|
| 288 |
+
outputs=[download_status]
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# 4. Run Training
|
| 292 |
+
# ----------------
|
| 293 |
+
run_training_btn.click(
|
| 294 |
+
fn=lambda: set_interactivity(False), outputs=action_buttons
|
| 295 |
+
).then(
|
| 296 |
+
fn=training_wrapper,
|
| 297 |
+
inputs=[session_state, stories_state, labels_state],
|
| 298 |
+
outputs=[output]
|
| 299 |
+
).then(
|
| 300 |
+
# Unlock all buttons (including downloads now that we have a model)
|
| 301 |
+
fn=lambda: set_interactivity(True), outputs=action_buttons
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
# 5. Downloads
|
| 305 |
+
# ----------------
|
| 306 |
+
download_dataset_btn.click(
|
| 307 |
+
fn=export_wrapper,
|
| 308 |
+
inputs=[session_state],
|
| 309 |
+
outputs=[dataset_output]
|
| 310 |
+
).then(
|
| 311 |
+
# Just show the file output if it exists
|
| 312 |
+
lambda p: gr.update(visible=True) if p else gr.update(),
|
| 313 |
+
inputs=[dataset_output],
|
| 314 |
+
outputs=[dataset_output]
|
| 315 |
+
)
|
| 316 |
|
| 317 |
+
download_model_btn.click(
|
| 318 |
+
# Lock UI
|
| 319 |
+
fn=lambda: set_interactivity(False), outputs=action_buttons
|
| 320 |
+
).then(
|
| 321 |
+
# Reset previous outputs and show "Zipping..."
|
| 322 |
+
fn=lambda: [gr.update(value=None, visible=False), "⏳ Zipping model..."],
|
| 323 |
+
outputs=[model_output, download_status]
|
| 324 |
+
).then(
|
| 325 |
+
# Generate Zip
|
| 326 |
+
fn=download_model_wrapper,
|
| 327 |
+
inputs=[session_state],
|
| 328 |
+
outputs=[model_output]
|
| 329 |
+
).then(
|
| 330 |
+
# Update UI with result
|
| 331 |
+
fn=lambda p: [gr.update(visible=p is not None, value=p), "✅ ZIP ready." if p else "❌ Zipping failed."],
|
| 332 |
+
inputs=[model_output],
|
| 333 |
+
outputs=[model_output, download_status]
|
| 334 |
+
).then(
|
| 335 |
+
# Unlock UI
|
| 336 |
+
fn=lambda: set_interactivity(True), outputs=action_buttons
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
# 6. Push to Hub
|
| 340 |
+
# ----------------
|
| 341 |
+
repo_name_input.change(
|
| 342 |
+
fn=update_repo_preview,
|
| 343 |
+
inputs=[username_state, repo_name_input],
|
| 344 |
+
outputs=[repo_id_preview]
|
| 345 |
+
)
|
| 346 |
|
| 347 |
+
push_to_hub_btn.click(
|
| 348 |
+
fn=push_to_hub_wrapper,
|
| 349 |
+
inputs=[session_state, repo_name_input],
|
| 350 |
+
outputs=[push_status]
|
| 351 |
+
)
|
|
|
|
| 352 |
|
| 353 |
with gr.Tab("📰 Hacker News Similarity Check"):
|
| 354 |
with gr.Column():
|