File size: 12,846 Bytes
6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 69c5213 ad95ef1 69c5213 ad95ef1 69c5213 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 ad95ef1 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 69c5213 6bd22b5 7a33ddf 6bd22b5 7a33ddf 6bd22b5 ad95ef1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
import os
import shutil
import time
import csv
import uuid
from itertools import cycle
from typing import List, Tuple, Optional
from datetime import datetime
import gradio as gr # Needed for gr.update, gr.Warning, gr.Info, gr.Error
from .data_fetcher import read_hacker_news_rss, format_published_time
from .model_trainer import (
authenticate_hf,
train_with_dataset,
get_top_hits,
load_embedding_model,
upload_model_to_hub
)
from .config import AppConfig
from .vibe_logic import VibeChecker
from sentence_transformers import SentenceTransformer
class HackerNewsFineTuner:
    """
    Encapsulates all application logic and state for a single user session.

    Each instance owns a unique artifacts directory keyed by a UUID session id,
    so multiple users can fetch data, train, export and upload models
    concurrently without clobbering each other's files.
    """

    def __init__(self, config: AppConfig = AppConfig):
        # --- Dependencies ---
        # NOTE: the default is the AppConfig *class* used as a namespace of
        # settings; an instance with the same attributes may also be passed.
        self.config = config

        # --- Session Identification ---
        self.session_id = str(uuid.uuid4())

        # Session-specific paths allow simultaneous training sessions.
        self.session_root = self.config.ARTIFACTS_DIR / self.session_id
        self.output_dir = self.session_root / "embedding_gemma_finetuned"
        self.dataset_export_file = self.session_root / "training_dataset.csv"

        # Setup directories (created eagerly; see _has_trained_model below).
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")

        # --- Application State ---
        self.model: Optional[SentenceTransformer] = None
        self.vibe_checker: Optional[VibeChecker] = None
        self.titles: List[str] = []
        self.last_hn_dataset: List[List[str]] = []
        self.imported_dataset: List[List[str]] = []

        # Authenticate once (global)
        authenticate_hf(self.config.HF_TOKEN)

    def _has_trained_model(self) -> bool:
        """Return True if training has written artifacts into output_dir.

        __init__ creates output_dir eagerly, so a bare os.path.exists()
        check would always succeed even before any training ran; we must
        check that the directory actually contains files.
        """
        return os.path.isdir(self.output_dir) and bool(os.listdir(self.output_dir))

    def _update_vibe_checker(self):
        """Initializes or updates the VibeChecker with the current model state."""
        if self.model:
            self.vibe_checker = VibeChecker(
                model=self.model,
                query_anchor=self.config.QUERY_ANCHOR,
                task_name=self.config.TASK_NAME
            )
        else:
            self.vibe_checker = None

    ## Data and Model Management ##
    def refresh_data_and_model(self) -> Tuple[List[str], str]:
        """
        Reloads model and fetches data.
        Returns:
            - List of titles (for the UI)
            - Status message string
        """
        print(f"[{self.session_id}] Reloading model and data...")
        self.last_hn_dataset = []
        self.imported_dataset = []

        # 1. Reload the base embedding model
        try:
            self.model = load_embedding_model(self.config.MODEL_NAME)
            self._update_vibe_checker()
        except Exception as e:
            error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
            print(error_msg)
            self.model = None
            self._update_vibe_checker()
            return [], error_msg

        # 2. Fetch fresh news data
        news_feed, status_msg = read_hacker_news_rss(self.config)
        titles_out = []
        status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
        if news_feed is not None and news_feed.entries:
            titles_out = [item.title for item in news_feed.entries]
        else:
            # Keep the UI populated with a sentinel entry and surface a warning.
            titles_out = ["Error fetching news."]
            gr.Warning(f"Data reload failed. {status_msg}")
        self.titles = titles_out

        # Return raw list of titles + status text
        return self.titles, status_value

    # --- Import Dataset/Export ---
    def import_additional_dataset(self, file_path: str) -> str:
        """Imports a CSV of (Anchor, Positive, Negative) triplets.

        Accepts files with or without a header row; rows that do not have
        exactly three columns are silently skipped. On success the triplets
        are stored in self.imported_dataset and take precedence over
        UI-selected stories during training.
        """
        if not file_path:
            return "Please upload a CSV file."
        new_dataset, num_imported = [], 0
        try:
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                try:
                    header = next(reader)
                    # Simple heuristic to detect if header exists; if the
                    # first row is not a header, rewind so it is re-read as data.
                    if not (header and header[0].lower().strip() == 'anchor'):
                        f.seek(0)
                except StopIteration:
                    return "Error: Uploaded file is empty."
                for row in reader:
                    if len(row) == 3:
                        new_dataset.append([s.strip() for s in row])
                        num_imported += 1
            if num_imported == 0:
                raise ValueError("No valid rows found.")
            self.imported_dataset = new_dataset
            return f"Imported {num_imported} triplets."
        except Exception as e:
            return f"Import failed: {e}"

    def export_dataset(self) -> Optional[str]:
        """Writes the last generated training dataset to a session-local CSV.

        Returns the file path for the Gradio download component, or None if
        there is nothing to export or the write failed.
        """
        if not self.last_hn_dataset:
            gr.Warning("No dataset generated yet.")
            return None
        file_path = self.dataset_export_file
        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(['Anchor', 'Positive', 'Negative'])
                writer.writerows(self.last_hn_dataset)
            gr.Info("Dataset exported.")
            return str(file_path)
        except Exception as e:
            gr.Error(f"Export failed: {e}")
            return None

    def download_model(self) -> Optional[str]:
        """Zips the fine-tuned model directory for download.

        Returns the archive path, or None when no model has been trained yet
        or archiving failed.
        """
        # output_dir always exists (created in __init__), so check contents.
        if not self._has_trained_model():
            gr.Warning("No model trained yet.")
            return None
        timestamp = int(time.time())
        try:
            base_name = self.session_root / f"model_finetuned_{timestamp}"
            archive_path = shutil.make_archive(
                base_name=str(base_name),
                format='zip',
                root_dir=self.output_dir,
            )
            gr.Info("Model zipped.")
            return archive_path
        except Exception as e:
            gr.Error(f"Zip failed: {e}")
            return None

    def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
        """
        Calls the model trainer upload function using the session's output directory.
        """
        # output_dir always exists (created in __init__), so check contents.
        if not self._has_trained_model():
            return "❌ Error: No trained model found in this session. Run training first."
        if not repo_name.strip():
            return "❌ Error: Please specify a repository name."
        return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)

    ## Training Logic ##
    def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]:
        """
        Creates triplets (Anchor, Positive, Negative) from the selected indices.
        Uses cycling to balance the dataset if the number of positives != negatives.
        """
        if not pos_ids or not neg_ids:
            return []
        # Convert indices to actual title strings
        pos_titles = [self.titles[i] for i in pos_ids]
        neg_titles = [self.titles[i] for i in neg_ids]
        dataset = []
        # We need to pair every Positive with a Negative.
        # Strategy: Iterate over the longer list and cycle through the shorter list
        # to ensure every selected item is used at least once and the dataset is balanced.
        if len(pos_titles) >= len(neg_titles):
            # More positives than negatives: Iterate positives, reuse negatives
            neg_cycle = cycle(neg_titles)
            for p_title in pos_titles:
                dataset.append([self.config.QUERY_ANCHOR, p_title, next(neg_cycle)])
        else:
            # More negatives than positives: Iterate negatives, reuse positives
            pos_cycle = cycle(pos_titles)
            for n_title in neg_titles:
                dataset.append([self.config.QUERY_ANCHOR, next(pos_cycle), n_title])
        return dataset

    def training(self, pos_ids: List[int], neg_ids: List[int]) -> str:
        """
        Main training entry point.
        Args:
            pos_ids: Indices of stories marked as "Favorite"
            neg_ids: Indices of stories marked as "Dislike"
        Returns:
            Markdown comparing semantic-search results before and after training.
        Raises:
            gr.Error: when the model is not loaded or the dataset is invalid.
        """
        if self.model is None:
            raise gr.Error("Model not loaded.")
        # An imported CSV dataset takes precedence over UI selections.
        if self.imported_dataset:
            self.last_hn_dataset = self.imported_dataset
        else:
            # Validation
            if not pos_ids:
                raise gr.Error("Please select at least one 'Favorite' story.")
            if not neg_ids:
                raise gr.Error("Please select at least one 'Dislike' story.")
            # Generate Dataset
            self.last_hn_dataset = self._create_hn_dataset(pos_ids, neg_ids)
        if not self.last_hn_dataset:
            raise gr.Error("Dataset generation failed (Empty dataset).")

        def semantic_search_fn() -> str:
            return get_top_hits(model=self.model, target_titles=self.titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)

        result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
        print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...")
        train_with_dataset(
            model=self.model,
            dataset=self.last_hn_dataset,
            output_dir=self.output_dir,
            task_name=self.config.TASK_NAME,
            search_fn=semantic_search_fn
        )
        self._update_vibe_checker()
        print(f"[{self.session_id}] Training Complete.")
        result += "### Search (After):\n" + f"{semantic_search_fn()}"
        return result

    def is_model_tuned(self) -> bool:
        """Return True once a training dataset has been generated or imported."""
        return bool(self.last_hn_dataset)

    ## Vibe Check Logic ##
    def get_vibe_check(self, news_text: str) -> Tuple[str, str, dict, str]:
        """Scores a single headline against the current vibe model.

        Returns a 4-tuple for the UI: (score text, status text,
        gr.update carrying the mood-lamp CSS, session info markdown).
        """
        model_name = "<unsaved>"
        if self.last_hn_dataset:
            model_name = f"./{self.output_dir}"
        info_text = (f"**Session:** {self.session_id[:6]}<br>"
                     f"**Base Model:** `{self.config.MODEL_NAME}`<br>"
                     f"**Tuned Model:** `{model_name}`")
        if not self.vibe_checker:
            return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_css("gray")), info_text
        if not news_text or len(news_text.split()) < 3:
            return "N/A", "Text too short", gr.update(value=self._generate_vibe_css("gray")), info_text
        try:
            vibe_result = self.vibe_checker.check(news_text)
            # Extract the inner text of the first HTML tag in status_html
            # (e.g. "<span ...>Hot</span>" -> "Hot"); any parse failure is
            # caught by the except below.
            status = vibe_result.status_html.split('>')[1].split('<')[0]
            return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_css(vibe_result.color_hsl)), info_text
        except Exception as e:
            return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_css("gray")), info_text

    def _generate_vibe_css(self, color: str) -> str:
        """Generates a style block to update the Mood Lamp textbox background."""
        return f"<style>#mood_lamp input {{ background-color: {color} !important; transition: background-color 0.5s ease; }}</style>"

    ## Mood Reader Logic ##
    def fetch_and_display_mood_feed(self) -> str:
        """Fetches the HN feed, scores each title, and renders a markdown table
        sorted by descending vibe score."""
        if not self.vibe_checker:
            return "Model not ready. Please wait or reload."
        feed, status = read_hacker_news_rss(self.config)
        if not feed or not feed.entries:
            return f"**Feed Error:** {status}"
        scored_entries = []
        for entry in feed.entries:
            title = entry.get('title')
            if not title:
                continue
            vibe_result = self.vibe_checker.check(title)
            scored_entries.append({
                "title": title,
                "link": entry.get('link', '#'),
                "comments": entry.get('comments', '#'),
                "published": format_published_time(entry.published_parsed),
                "mood": vibe_result
            })
        scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
        model_name = "<unsaved>"
        if self.last_hn_dataset:
            model_name = f"./{self.output_dir}"
        md = (f"## Hacker News Top Stories\n"
              f"**Session:** {self.session_id[:6]}<br>"
              f"**Base Model:** `{self.config.MODEL_NAME}`<br>"
              f"**Tuned Model:** `{model_name}`<br>"
              f"**Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
              "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
        for item in scored_entries:
            md += (f"| {item['mood'].status_html} "
                   f"| {item['mood'].raw_score:.4f} "
                   f"| [{item['title']}]({item['link']}) "
                   f"| [Comments]({item['comments']}) "
                   f"| {item['published']} |\n")
        return md
|