File size: 12,846 Bytes
6bd22b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad95ef1
6bd22b5
 
ad95ef1
 
 
6bd22b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad95ef1
6bd22b5
 
 
ad95ef1
6bd22b5
 
 
ad95ef1
6bd22b5
ad95ef1
6bd22b5
 
 
ad95ef1
 
 
6bd22b5
 
 
 
 
 
 
 
 
 
 
ad95ef1
6bd22b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad95ef1
 
 
 
 
 
 
6bd22b5
ad95ef1
 
 
6bd22b5
ad95ef1
6bd22b5
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
6bd22b5
ad95ef1
6bd22b5
ad95ef1
 
 
 
 
 
 
6bd22b5
 
69c5213
ad95ef1
69c5213
ad95ef1
69c5213
 
 
 
 
 
 
 
ad95ef1
 
 
6bd22b5
 
ad95ef1
6bd22b5
 
ad95ef1
6bd22b5
 
 
ad95ef1
6bd22b5
 
 
 
 
 
 
 
 
 
 
ad95ef1
 
 
6bd22b5
 
69c5213
 
 
 
 
 
 
6bd22b5
 
69c5213
6bd22b5
69c5213
6bd22b5
 
 
 
69c5213
6bd22b5
69c5213
6bd22b5
69c5213
 
 
6bd22b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a33ddf
 
 
 
6bd22b5
 
7a33ddf
 
 
6bd22b5
 
 
 
 
 
 
 
ad95ef1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import os
import shutil
import time
import csv
import uuid
from itertools import cycle
from typing import List, Tuple, Optional
from datetime import datetime
import gradio as gr # Needed for gr.update, gr.Warning, gr.Info, gr.Error

from .data_fetcher import read_hacker_news_rss, format_published_time
from .model_trainer import (
    authenticate_hf,
    train_with_dataset,
    get_top_hits,
    load_embedding_model,
    upload_model_to_hub
)
from .config import AppConfig
from .vibe_logic import VibeChecker
from sentence_transformers import SentenceTransformer

class HackerNewsFineTuner:
    """
    Encapsulates all application logic and state for a single user session.
    """

    def __init__(self, config: AppConfig = AppConfig):
        """Create a per-session workspace and authenticate with HF.

        Args:
            config: Configuration provider. NOTE: the default is the
                AppConfig *class itself* (not an instance); it works because
                only class-level attributes are read — TODO confirm this is
                intentional rather than a missing `AppConfig()`.
        """
        # --- Dependencies ---
        self.config = config

        # --- Session Identification ---
        # Random UUID isolates this user's artifacts from other sessions.
        self.session_id = str(uuid.uuid4())

        # Define session-specific paths to allow simultaneous training
        self.session_root = self.config.ARTIFACTS_DIR / self.session_id
        self.output_dir = self.session_root / "embedding_gemma_finetuned"
        self.dataset_export_file = self.session_root / "training_dataset.csv"

        # Setup directories (also creates session_root as a parent).
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")

        # --- Application State ---
        # model / vibe_checker stay None until refresh_data_and_model() runs.
        self.model: Optional[SentenceTransformer] = None
        self.vibe_checker: Optional[VibeChecker] = None
        self.titles: List[str] = []
        self.last_hn_dataset: List[List[str]] = []
        self.imported_dataset: List[List[str]] = []

        # Authenticate once (global)
        authenticate_hf(self.config.HF_TOKEN)

    def _update_vibe_checker(self):
        """Rebuild the VibeChecker from the current model state.

        Clears the checker to None when no model is loaded so callers can
        detect the not-ready state.
        """
        if self.model is None:
            self.vibe_checker = None
            return
        self.vibe_checker = VibeChecker(
            model=self.model,
            query_anchor=self.config.QUERY_ANCHOR,
            task_name=self.config.TASK_NAME,
        )

    ## Data and Model Management ##

    def refresh_data_and_model(self) -> Tuple[List[str], str]:
        """Reload the base embedding model and refetch the HN feed.

        Returns:
            A tuple of (story titles for the UI, status message string).
            On a model-load failure the title list is empty and the status
            carries the error text.
        """
        print(f"[{self.session_id}] Reloading model and data...")

        # Discard any dataset left over from a previous round.
        self.last_hn_dataset = []
        self.imported_dataset = []

        # Step 1: reload the base embedding model.
        try:
            self.model = load_embedding_model(self.config.MODEL_NAME)
            self._update_vibe_checker()
        except Exception as e:
            error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
            print(error_msg)
            self.model = None
            self._update_vibe_checker()
            return [], error_msg

        # Step 2: fetch fresh news data.
        news_feed, status_msg = read_hacker_news_rss(self.config)
        status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"

        if news_feed is not None and news_feed.entries:
            fetched_titles = [entry.title for entry in news_feed.entries]
        else:
            fetched_titles = ["Error fetching news."]
            gr.Warning(f"Data reload failed. {status_msg}")

        self.titles = fetched_titles

        # Raw list of titles plus the status line for the UI.
        return self.titles, status_value

    # --- Import Dataset/Export ---
    def import_additional_dataset(self, file_path: str) -> str:
        """Load (Anchor, Positive, Negative) triplets from an uploaded CSV.

        A header row is skipped when its first cell equals 'anchor'
        (case-insensitive); rows without exactly three columns are ignored.
        On success the triplets replace `self.imported_dataset`.
        """
        if not file_path:
            return "Please upload a CSV file."
        triplets: List[List[str]] = []
        try:
            with open(file_path, 'r', newline='', encoding='utf-8') as f:
                reader = csv.reader(f)
                try:
                    first_row = next(reader)
                    # Simple heuristic: rewind unless the row looks like a header.
                    if not (first_row and first_row[0].lower().strip() == 'anchor'):
                        f.seek(0)
                except StopIteration:
                    return "Error: Uploaded file is empty."

                for row in reader:
                    if len(row) == 3:
                        triplets.append([cell.strip() for cell in row])
            if not triplets:
                raise ValueError("No valid rows found.")
            self.imported_dataset = triplets
            return f"Imported {len(triplets)} triplets."
        except Exception as e:
            return f"Import failed: {e}"

    def export_dataset(self) -> Optional[str]:
        """Write the last generated dataset to the session CSV file.

        Returns:
            The export file path as a string, or None when nothing has been
            generated yet or the write fails (a Gradio message is shown in
            either case).
        """
        if not self.last_hn_dataset:
            gr.Warning("No dataset generated yet.")
            return None

        file_path = self.dataset_export_file
        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow(['Anchor', 'Positive', 'Negative'])
                writer.writerows(self.last_hn_dataset)
            # Fixed: was an f-string with no placeholders (ruff F541).
            gr.Info("Dataset exported.")
            return str(file_path)
        except Exception as e:
            gr.Error(f"Export failed: {e}")
            return None

    def download_model(self) -> Optional[str]:
        """Zip the fine-tuned model directory and return the archive path.

        Returns:
            Path to the created .zip archive, or None when no model has been
            trained yet or archiving fails (a Gradio message is shown).
        """
        # Bug fix: __init__ always pre-creates output_dir, so a bare
        # os.path.exists() check was always true and "No model trained yet"
        # could never fire — an empty zip was produced instead. Also require
        # the directory to contain files.
        if not os.path.exists(self.output_dir) or not os.listdir(self.output_dir):
            gr.Warning("No model trained yet.")
            return None

        timestamp = int(time.time())
        try:
            base_name = self.session_root / f"model_finetuned_{timestamp}"
            archive_path = shutil.make_archive(
                base_name=str(base_name),
                format='zip',
                root_dir=self.output_dir,
            )
            # Fixed: was an f-string with no placeholders (ruff F541).
            gr.Info("Model zipped.")
            return archive_path
        except Exception as e:
            gr.Error(f"Zip failed: {e}")
            return None

    def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
        """
        Calls the model trainer upload function using the session's output directory.

        Args:
            repo_name: Target Hugging Face repository name.
            oauth_token_str: OAuth token string forwarded to the uploader.

        Returns:
            A status message string (error messages are prefixed with ❌).
        """
        # Bug fix: output_dir is pre-created in __init__, so existence alone
        # never distinguished "trained" from "untrained"; also require content.
        if not os.path.exists(self.output_dir) or not os.listdir(self.output_dir):
            return "❌ Error: No trained model found in this session. Run training first."
        if not repo_name.strip():
            return "❌ Error: Please specify a repository name."

        return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)


    ## Training Logic ##
    def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]:
        """Build (Anchor, Positive, Negative) triplets from selected indices.

        Every selected title is used at least once: the longer selection is
        iterated in full while the shorter one is cycled, keeping the
        dataset balanced when the counts differ.
        """
        if not pos_ids or not neg_ids:
            return []

        # Resolve indices into the actual title strings.
        positives = [self.titles[i] for i in pos_ids]
        negatives = [self.titles[i] for i in neg_ids]

        # Pair the longer list against an endless repeat of the shorter one.
        if len(positives) >= len(negatives):
            pairs = zip(positives, cycle(negatives))
        else:
            pairs = zip(cycle(positives), negatives)

        anchor = self.config.QUERY_ANCHOR
        return [[anchor, pos, neg] for pos, neg in pairs]

    def training(self, pos_ids: List[int], neg_ids: List[int]) -> str:
        """Run fine-tuning and return before/after search results as markdown.

        Args:
            pos_ids: Indices of stories marked as "Favorite".
            neg_ids: Indices of stories marked as "Dislike".

        An imported CSV dataset, when present, takes precedence over the UI
        selections. Raises gr.Error on validation failures.
        """
        if self.model is None:
            raise gr.Error("Model not loaded.")

        if self.imported_dataset:
            self.last_hn_dataset = self.imported_dataset
        else:
            # Validate the UI selections before generating triplets.
            if not pos_ids:
                raise gr.Error("Please select at least one 'Favorite' story.")
            if not neg_ids:
                raise gr.Error("Please select at least one 'Dislike' story.")
            self.last_hn_dataset = self._create_hn_dataset(pos_ids, neg_ids)

        if not self.last_hn_dataset:
            raise gr.Error("Dataset generation failed (Empty dataset).")

        def semantic_search_fn() -> str:
            # Runs the anchor query against the current titles with whatever
            # state the model is in at call time.
            return get_top_hits(model=self.model, target_titles=self.titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)

        result = f"### Search (Before):\n{semantic_search_fn()}\n\n"
        print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...")

        train_with_dataset(
            model=self.model,
            dataset=self.last_hn_dataset,
            output_dir=self.output_dir,
            task_name=self.config.TASK_NAME,
            search_fn=semantic_search_fn,
        )

        # Refresh the checker so vibe scoring uses the tuned weights.
        self._update_vibe_checker()
        print(f"[{self.session_id}] Training Complete.")

        result += f"### Search (After):\n{semantic_search_fn()}"
        return result

    def is_model_tuned(self) -> bool:
        """Return True when a training dataset has been generated this session."""
        # Idiom fix: `True if x else False` is just bool(x).
        return bool(self.last_hn_dataset)

    ## Vibe Check Logic ##
    def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update, str]:
        """Score a piece of text against the current model's vibe.

        Note: the return annotation previously declared three elements, but
        every path returns four: (raw score text, status label, gr.update
        carrying mood-lamp CSS, session/model info markdown).
        """
        model_name = "<unsaved>"
        if self.last_hn_dataset:
            model_name = f"./{self.output_dir}"

        info_text = (f"**Session:** {self.session_id[:6]}<br>"
                     f"**Base Model:** `{self.config.MODEL_NAME}`<br>"
                     f"**Tuned Model:** `{model_name}`")

        # Not-ready / too-short inputs fall back to a gray lamp.
        if not self.vibe_checker:
            return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_css("gray")), info_text
        if not news_text or len(news_text.split()) < 3:
            return "N/A", "Text too short", gr.update(value=self._generate_vibe_css("gray")), info_text

        try:
            vibe_result = self.vibe_checker.check(news_text)
            # Extracts the inner text of the first tag in status_html,
            # e.g. "<span ...>Label</span>" -> "Label" (fragile HTML scrape).
            status = vibe_result.status_html.split('>')[1].split('<')[0]
            return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_css(vibe_result.color_hsl)), info_text
        except Exception as e:
            return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_css("gray")), info_text

    def _generate_vibe_css(self, color: str) -> str:
        """Generates a style block to update the Mood Lamp textbox background."""
        template = (
            "<style>#mood_lamp input {{ background-color: {} !important; "
            "transition: background-color 0.5s ease; }}</style>"
        )
        return template.format(color)

    ## Mood Reader Logic ##
    def fetch_and_display_mood_feed(self) -> str:
        """Fetch the HN feed, score every story, and render a markdown table.

        Stories are sorted by descending vibe score; entries without a title
        are skipped. Returns an error message string when the model or feed
        is unavailable.
        """
        if not self.vibe_checker:
            return "Model not ready. Please wait or reload."

        feed, status = read_hacker_news_rss(self.config)
        if not feed or not feed.entries:
            return f"**Feed Error:** {status}"

        scored_entries = []
        for entry in feed.entries:
            title = entry.get('title')
            if not title:
                continue
            scored_entries.append({
                "title": title,
                "link": entry.get('link', '#'),
                "comments": entry.get('comments', '#'),
                "published": format_published_time(entry.published_parsed),
                "mood": self.vibe_checker.check(title),
            })

        # Best vibes first.
        scored_entries.sort(key=lambda e: e["mood"].raw_score, reverse=True)

        model_name = f"./{self.output_dir}" if self.last_hn_dataset else "<unsaved>"

        header = (f"## Hacker News Top Stories\n"
                  f"**Session:** {self.session_id[:6]}<br>"
                  f"**Base Model:** `{self.config.MODEL_NAME}`<br>"
                  f"**Tuned Model:** `{model_name}`<br>"
                  f"**Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
                  "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")

        rows = "".join(
            f"| {item['mood'].status_html} "
            f"| {item['mood'].raw_score:.4f} "
            f"| [{item['title']}]({item['link']}) "
            f"| [Comments]({item['comments']}) "
            f"| {item['published']} |\n"
            for item in scored_entries
        )
        return header + rows