Upload folder using huggingface_hub
Browse files- README.md +161 -7
- app.py +364 -0
- cli_mood_reader.py +179 -0
- config.py +56 -0
- data_fetcher.py +112 -0
- flask_app.py +58 -0
- hn_mood_reader.py +71 -0
- model_trainer.py +132 -0
- requirements.txt +9 -0
- templates/error.html +13 -0
- templates/index.html +127 -0
- vibe_logic.py +85 -0
README.md
CHANGED
|
@@ -1,14 +1,168 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
short_description: EmbeddingGemma Mod Kit
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Embedding Gemma Modkit
|
| 3 |
+
emoji: 😻
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# 🤖 Embedding Gemma Modkit: Fine-Tuning and Mood Reader
|
| 13 |
+
|
| 14 |
+
This project provides a set of tools to fine-tune a sentence-embedding model to understand your personal taste in Hacker News titles and then use that model to score and rank new articles based on their "vibe."
|
| 15 |
+
|
| 16 |
+
It includes three main applications:
|
| 17 |
+
1. A **Gradio App** for interactive fine-tuning, evaluation, and real-time "vibe checks."
|
| 18 |
+
2. An interactive **Command-Line (CLI) App** for viewing and scrolling through the scored feed directly in your terminal.
|
| 19 |
+
3. A **Flask App** for a simple, deployable web "mood reader" that displays the live HN feed.
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## ✨ Features
|
| 24 |
+
|
| 25 |
+
* **Interactive Fine-Tuning:** Use a Gradio interface to select your favorite Hacker News titles and fine-tune the `google/embeddinggemma-300m` model on your preferences.
|
| 26 |
+
* **Semantic Search Evaluation:** See the immediate impact of your training by comparing semantic search results before and after fine-tuning.
|
| 27 |
+
* **Live "Vibe Check":** Input any news title or text to get a real-time similarity score (its "vibe") against your personalized anchor.
|
| 28 |
+
* **Interactive CLI:** A terminal-based mood reader with color-coded output, scrolling, and live refresh capabilities.
|
| 29 |
+
* **Hacker News Mood Reader:** View the live Hacker News feed with each story scored and color-coded based on the current model's understanding of your taste.
|
| 30 |
+
* **Data & Model Management:** Easily import additional training data, export the generated dataset, and download the fine-tuned model as a ZIP file.
|
| 31 |
+
* **Standalone Flask App:** A lightweight, read-only web app to continuously display the scored HN feed, perfect for simple deployment.
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## 🔧 How It Works
|
| 36 |
+
|
| 37 |
+
The core idea is to measure the "vibe" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, defined in `config.py` as **`QUERY_ANCHOR`** (default value: **`"MY_FAVORITE_NEWS"`**).
|
| 38 |
+
|
| 39 |
+
1. **Embedding:** The `sentence-transformers` library is used to convert news titles and the anchor phrase into high-dimensional vectors (embeddings).
|
| 40 |
+
2. **Scoring:** The cosine similarity (or dot product on normalized embeddings) between a title's embedding and the anchor's embedding is calculated. A higher score means a better "vibe."
|
| 41 |
+
3. **Fine-Tuning:** The Gradio app generates a contrastive learning dataset from your selections.
|
| 42 |
+
* **Positive Pairs:** (`MY_FAVORITE_NEWS`, `[A title you selected]`)
|
| 43 |
+
* **Negative Pairs:** (`MY_FAVORITE_NEWS`, `[A title you did not select]`)
|
| 44 |
+
4. **Training:** The model is trained using `MultipleNegativesRankingLoss`, which fine-tunes it to pull the embeddings of your "favorite" titles closer to the anchor phrase and push the others away.
|
| 45 |
+
|
| 46 |
+
## 🚀 Getting Started
|
| 47 |
+
|
| 48 |
+
### 1. Prerequisites
|
| 49 |
+
* Python 3.12+
|
| 50 |
+
* Git
|
| 51 |
+
|
| 52 |
+
### 2. Installation
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
# Clone the repository
|
| 56 |
+
git clone https://huggingface.co/spaces/bebechien/news-vibe-checker
|
| 57 |
+
cd news-vibe-checker
|
| 58 |
+
|
| 59 |
+
# Create and activate a virtual environment (recommended)
|
| 60 |
+
python -m venv venv
|
| 61 |
+
source venv/bin/activate # On Windows, use `venv\Scripts\activate`
|
| 62 |
+
|
| 63 |
+
# Install the required packages
|
| 64 |
+
pip install -r requirements.txt
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### 3. (Optional) Hugging Face Authentication
|
| 68 |
+
|
| 69 |
+
If you plan to use gated models or push your fine-tuned model to the Hugging Face Hub, you need to authenticate.
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
# Set your Hugging Face token as an environment variable
|
| 73 |
+
export HF_TOKEN="your_hf_token_here"
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
-----
|
| 77 |
+
|
| 78 |
+
## 🖥️ Running the Applications
|
| 79 |
+
|
| 80 |
+
You can run any of the three applications depending on your needs.
|
| 81 |
+
|
| 82 |
+
### Option A: Interactive Fine-Tuning (Gradio App)
|
| 83 |
+
|
| 84 |
+
This is the main application for creating and evaluating a personalized model.
|
| 85 |
+
|
| 86 |
+
**▶️ To run:**
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
python app.py
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
Navigate to the local URL provided (e.g., `http://127.0.0.1:7860`).
|
| 93 |
+
|
| 94 |
+
### Option B: Interactive Terminal Viewer (CLI App)
|
| 95 |
+
|
| 96 |
+
This app runs directly in your terminal, allowing you to quickly see and scroll through the scored Hacker News feed.
|
| 97 |
+
|
| 98 |
+
**▶️ To run:**
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
python cli_mood_reader.py
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
**Interactive Controls:**
|
| 105 |
+
|
| 106 |
+
* **[↑|↓]** arrow keys to scroll through the story list.
|
| 107 |
+
* **[SPACE]** to refresh the feed with the latest stories.
|
| 108 |
+
* **[q]** to quit the application.
|
| 109 |
+
|
| 110 |
+
You can also start it with options:
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
# Specify a different model from Hugging Face
|
| 114 |
+
python cli_mood_reader.py --model google/embeddinggemma-300m
|
| 115 |
+
|
| 116 |
+
# Show 10 stories per screen instead of the default 15
|
| 117 |
+
python cli_mood_reader.py --top 10
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### Option C: Standalone Web Viewer (Flask App)
|
| 121 |
+
|
| 122 |
+
This app is a simple, read-only web page that fetches and displays the scored HN feed. It's ideal for deploying a finished model.
|
| 123 |
+
|
| 124 |
+
**▶️ To run:**
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
# (Optional) Specify a model from the Hugging Face Hub
|
| 128 |
+
export MOOD_MODEL="bebechien/embedding-gemma-finetuned-hn"
|
| 129 |
+
|
| 130 |
+
# Run the Flask server
|
| 131 |
+
python flask_app.py
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Navigate to `http://127.0.0.1:5000` to see the results.
|
| 135 |
+
|
| 136 |
+
-----
|
| 137 |
+
|
| 138 |
+
## ⚙️ Configuration
|
| 139 |
+
|
| 140 |
+
Key parameters can be adjusted in `config.py`:
|
| 141 |
+
|
| 142 |
+
* `MODEL_NAME`: The base model to use for fine-tuning (e.g., `'google/embeddinggemma-300m'`).
|
| 143 |
+
* `QUERY_ANCHOR`: The anchor text used for similarity scoring (e.g., `"MY_FAVORITE_NEWS"`).
|
| 144 |
+
* `DEFAULT_MOOD_READER_MODEL`: The default model used by the Flask and CLI apps.
|
| 145 |
+
* `HN_RSS_URL`: The RSS feed URL.
|
| 146 |
+
* `CACHE_DURATION_SECONDS`: How long to cache the RSS feed data.
|
| 147 |
+
|
| 148 |
+
-----
|
| 149 |
+
|
| 150 |
+
## 📂 File Structure
|
| 151 |
+
|
| 152 |
+
```
|
| 153 |
+
.
|
| 154 |
+
├── app.py # Main Gradio application for fine-tuning
|
| 155 |
+
├── cli_mood_reader.py # Interactive command-line mood reader
|
| 156 |
+
├── flask_app.py # Standalone Flask application for mood reading
|
| 157 |
+
├── hn_mood_reader.py # Core logic for fetching and scoring (used by Flask/CLI)
|
| 158 |
+
├── model_trainer.py # Handles model loading and fine-tuning
|
| 159 |
+
├── vibe_logic.py # Calculates similarity scores and "vibe" status
|
| 160 |
+
├── data_fetcher.py # Fetches and caches the Hacker News RSS feed
|
| 161 |
+
├── config.py # Central configuration for all modules
|
| 162 |
+
├── requirements.txt # Python package dependencies
|
| 163 |
+
├── README.md # This file
|
| 164 |
+
└── templates/ # HTML templates for the Flask app
|
| 165 |
+
├── index.html
|
| 166 |
+
└── error.html
|
| 167 |
+
```
|
| 168 |
+
|
app.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
import time
|
| 5 |
+
import csv
|
| 6 |
+
from itertools import cycle
|
| 7 |
+
from typing import List, Iterable, Tuple, Optional, Callable
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
|
| 10 |
+
# Import modules
|
| 11 |
+
from data_fetcher import read_hacker_news_rss, format_published_time
|
| 12 |
+
from model_trainer import (
|
| 13 |
+
authenticate_hf,
|
| 14 |
+
train_with_dataset,
|
| 15 |
+
get_top_hits,
|
| 16 |
+
load_embedding_model
|
| 17 |
+
)
|
| 18 |
+
from config import AppConfig
|
| 19 |
+
from vibe_logic import VibeChecker
|
| 20 |
+
from sentence_transformers import SentenceTransformer
|
| 21 |
+
|
| 22 |
+
# --- Main Application Class ---
|
| 23 |
+
|
| 24 |
+
class HackerNewsFineTuner:
|
| 25 |
+
"""
|
| 26 |
+
Encapsulates all application logic and state for the Gradio interface.
|
| 27 |
+
Manages the embedding model, news data, and training datasets.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, config: AppConfig = AppConfig):
|
| 31 |
+
# --- Dependencies ---
|
| 32 |
+
self.config = config
|
| 33 |
+
|
| 34 |
+
# --- Application State ---
|
| 35 |
+
self.model: Optional[SentenceTransformer] = None
|
| 36 |
+
self.vibe_checker: Optional[VibeChecker] = None
|
| 37 |
+
self.titles: List[str] = [] # Top titles for user selection
|
| 38 |
+
self.target_titles: List[str] = [] # Remaining titles for semantic search target pool
|
| 39 |
+
self.number_list: List[int] = [] # [0, 1, 2, ...] for checkbox indexing
|
| 40 |
+
self.last_hn_dataset: List[List[str]] = [] # Last generated dataset from HN selection
|
| 41 |
+
self.imported_dataset: List[List[str]] = [] # Manually imported dataset
|
| 42 |
+
|
| 43 |
+
# Setup
|
| 44 |
+
os.makedirs(self.config.ARTIFACTS_DIR, exist_ok=True)
|
| 45 |
+
print(f"Created artifact directory: {self.config.ARTIFACTS_DIR}")
|
| 46 |
+
|
| 47 |
+
authenticate_hf(self.config.HF_TOKEN)
|
| 48 |
+
|
| 49 |
+
# Load initial data on startup
|
| 50 |
+
self._initial_load()
|
| 51 |
+
|
| 52 |
+
def _initial_load(self):
|
| 53 |
+
"""Helper to run the refresh function once at startup."""
|
| 54 |
+
print("--- Running Initial Data Load ---")
|
| 55 |
+
self.refresh_data_and_model()
|
| 56 |
+
print("--- Initial Load Complete ---")
|
| 57 |
+
|
| 58 |
+
def _update_vibe_checker(self):
|
| 59 |
+
"""Initializes or updates the VibeChecker with the current model state."""
|
| 60 |
+
if self.model:
|
| 61 |
+
print("Updating VibeChecker instance with the current model.")
|
| 62 |
+
self.vibe_checker = VibeChecker(
|
| 63 |
+
model=self.model,
|
| 64 |
+
query_anchor=self.config.QUERY_ANCHOR,
|
| 65 |
+
task_name=self.config.TASK_NAME
|
| 66 |
+
)
|
| 67 |
+
else:
|
| 68 |
+
self.vibe_checker = None
|
| 69 |
+
|
| 70 |
+
## Data and Model Management ##
|
| 71 |
+
|
| 72 |
+
def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
|
| 73 |
+
"""
|
| 74 |
+
1. Reloads the embedding model to clear fine-tuning.
|
| 75 |
+
2. Fetches fresh news data (from cache or web).
|
| 76 |
+
3. Updates the class state and returns Gradio updates for the UI.
|
| 77 |
+
"""
|
| 78 |
+
print("\n" + "=" * 50)
|
| 79 |
+
print("RELOADING MODEL and RE-FETCHING DATA")
|
| 80 |
+
|
| 81 |
+
# Reset dataset state
|
| 82 |
+
self.last_hn_dataset = []
|
| 83 |
+
self.imported_dataset = []
|
| 84 |
+
|
| 85 |
+
# 1. Reload the base embedding model
|
| 86 |
+
try:
|
| 87 |
+
self.model = load_embedding_model(self.config.MODEL_NAME)
|
| 88 |
+
self._update_vibe_checker()
|
| 89 |
+
except Exception as e:
|
| 90 |
+
gr.Error(f"Model load failed: {e}")
|
| 91 |
+
self.model = None
|
| 92 |
+
self._update_vibe_checker()
|
| 93 |
+
return (
|
| 94 |
+
gr.update(choices=[], label="Model Load Failed"),
|
| 95 |
+
gr.update(value=f"CRITICAL ERROR: Model failed to load. {e}")
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# 2. Fetch fresh news data
|
| 99 |
+
news_feed, status_msg = read_hacker_news_rss(self.config)
|
| 100 |
+
titles_out, target_titles_out = [], []
|
| 101 |
+
status_value: str = f"Model and data reloaded. Status: {status_msg}. Click 'Run Fine-Tuning' to begin."
|
| 102 |
+
|
| 103 |
+
if news_feed is not None and news_feed.entries:
|
| 104 |
+
# Use constant for clarity
|
| 105 |
+
titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
|
| 106 |
+
target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
|
| 107 |
+
print(f"Data reloaded: {len(titles_out)} selection titles, {len(target_titles_out)} target titles.")
|
| 108 |
+
else:
|
| 109 |
+
titles_out = ["Error fetching news, check console.", "Could not load feed.", "No data available."]
|
| 110 |
+
gr.Warning(f"Data reload failed. Using error placeholders. Details: {status_msg}")
|
| 111 |
+
|
| 112 |
+
self.titles = titles_out
|
| 113 |
+
self.target_titles = target_titles_out
|
| 114 |
+
self.number_list = list(range(len(self.titles)))
|
| 115 |
+
|
| 116 |
+
# Return Gradio updates for CheckboxGroup and Textbox
|
| 117 |
+
return (
|
| 118 |
+
gr.update(
|
| 119 |
+
choices=self.titles,
|
| 120 |
+
label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
|
| 121 |
+
),
|
| 122 |
+
gr.update(value=status_value)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# --- Import Dataset/Export ---
|
| 126 |
+
def import_additional_dataset(self, file_path: str) -> str:
|
| 127 |
+
if not file_path:
|
| 128 |
+
return "Please upload a CSV file."
|
| 129 |
+
new_dataset, num_imported = [], 0
|
| 130 |
+
try:
|
| 131 |
+
with open(file_path, 'r', newline='', encoding='utf-8') as f:
|
| 132 |
+
reader = csv.reader(f)
|
| 133 |
+
try:
|
| 134 |
+
header = next(reader)
|
| 135 |
+
if not (header and header[0].lower().strip() == 'anchor'):
|
| 136 |
+
f.seek(0)
|
| 137 |
+
except StopIteration:
|
| 138 |
+
return "Error: Uploaded file is empty."
|
| 139 |
+
|
| 140 |
+
for row in reader:
|
| 141 |
+
if len(row) == 3:
|
| 142 |
+
new_dataset.append([s.strip() for s in row])
|
| 143 |
+
num_imported += 1
|
| 144 |
+
if num_imported == 0:
|
| 145 |
+
raise ValueError("No valid [Anchor, Positive, Negative] rows found in the CSV.")
|
| 146 |
+
self.imported_dataset = new_dataset
|
| 147 |
+
return f"Successfully imported {num_imported} additional training triplets."
|
| 148 |
+
except Exception as e:
|
| 149 |
+
gr.Error(f"Import failed. Ensure the CSV format is: [Anchor, Positive, Negative]. Error: {e}")
|
| 150 |
+
return "Import failed. Check console for details."
|
| 151 |
+
|
| 152 |
+
def export_dataset(self) -> Optional[str]:
|
| 153 |
+
if not self.last_hn_dataset:
|
| 154 |
+
gr.Warning("No dataset has been generated from current selection yet. Please run fine-tuning first.")
|
| 155 |
+
return None
|
| 156 |
+
file_path = self.config.DATASET_EXPORT_FILENAME
|
| 157 |
+
try:
|
| 158 |
+
print(f"Exporting dataset to {file_path}...")
|
| 159 |
+
with open(file_path, 'w', newline='', encoding='utf-8') as f:
|
| 160 |
+
writer = csv.writer(f)
|
| 161 |
+
writer.writerow(['Anchor', 'Positive', 'Negative'])
|
| 162 |
+
writer.writerows(self.last_hn_dataset)
|
| 163 |
+
gr.Info(f"Dataset successfully exported to {file_path}")
|
| 164 |
+
return str(file_path)
|
| 165 |
+
except Exception as e:
|
| 166 |
+
gr.Error(f"Failed to export the dataset to CSV. Error: {e}")
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
def download_model(self) -> Optional[str]:
|
| 170 |
+
if not os.path.exists(self.config.OUTPUT_DIR):
|
| 171 |
+
gr.Warning(f"The model directory '{self.config.OUTPUT_DIR}' does not exist. Please run training first.")
|
| 172 |
+
return None
|
| 173 |
+
timestamp = int(time.time())
|
| 174 |
+
try:
|
| 175 |
+
base_name = os.path.join(self.config.ARTIFACTS_DIR, f"embedding_gemma_finetuned_{timestamp}")
|
| 176 |
+
archive_path = shutil.make_archive(
|
| 177 |
+
base_name=base_name,
|
| 178 |
+
format='zip',
|
| 179 |
+
root_dir=self.config.OUTPUT_DIR,
|
| 180 |
+
)
|
| 181 |
+
gr.Info(f"Model files successfully zipped to: {archive_path}")
|
| 182 |
+
return archive_path
|
| 183 |
+
except Exception as e:
|
| 184 |
+
gr.Error(f"Failed to create the model ZIP file. Error: {e}")
|
| 185 |
+
return None
|
| 186 |
+
|
| 187 |
+
## Training Logic ##
|
| 188 |
+
def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
|
| 189 |
+
"""
|
| 190 |
+
Internal function to generate the [Anchor, Positive, Negative] triplets
|
| 191 |
+
from the user's Hacker News title selection.
|
| 192 |
+
Returns (dataset, favorite_title, non_favorite_title)
|
| 193 |
+
"""
|
| 194 |
+
total_ids, selected_ids = set(self.number_list), set(selected_ids)
|
| 195 |
+
non_selected_ids = total_ids - selected_ids
|
| 196 |
+
is_minority = len(selected_ids) < (len(total_ids) / 2)
|
| 197 |
+
|
| 198 |
+
anchor_ids, pool_ids = (non_selected_ids, list(selected_ids)) if is_minority else (selected_ids, list(non_selected_ids))
|
| 199 |
+
|
| 200 |
+
def get_titles(anchor_id, pool_id):
|
| 201 |
+
return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
|
| 202 |
+
|
| 203 |
+
fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
|
| 204 |
+
non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
|
| 205 |
+
|
| 206 |
+
hn_dataset = []
|
| 207 |
+
pool_cycler = cycle(pool_ids)
|
| 208 |
+
for anchor_id in sorted(list(anchor_ids)):
|
| 209 |
+
fav, non_fav = get_titles(anchor_id, next(pool_cycler))
|
| 210 |
+
hn_dataset.append([self.config.QUERY_ANCHOR, fav, non_fav])
|
| 211 |
+
|
| 212 |
+
return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
|
| 213 |
+
|
| 214 |
+
def training(self, selected_ids: List[int]) -> str:
|
| 215 |
+
"""
|
| 216 |
+
Generates a training dataset from user selection and runs the fine-tuning process.
|
| 217 |
+
"""
|
| 218 |
+
if self.model is None:
|
| 219 |
+
raise gr.Error("Training failed: Embedding model is not loaded.")
|
| 220 |
+
if not selected_ids:
|
| 221 |
+
raise gr.Error("You must select at least one title.")
|
| 222 |
+
if len(selected_ids) == len(self.number_list):
|
| 223 |
+
raise gr.Error("You can't select all titles; a non-favorite is needed.")
|
| 224 |
+
|
| 225 |
+
hn_dataset, example_fav, _ = self._create_hn_dataset(selected_ids)
|
| 226 |
+
self.last_hn_dataset = hn_dataset
|
| 227 |
+
final_dataset = self.last_hn_dataset + self.imported_dataset
|
| 228 |
+
if not final_dataset:
|
| 229 |
+
raise gr.Error("Training failed: Final dataset is empty.")
|
| 230 |
+
print(f"Combined dataset size: {len(final_dataset)} triplets.")
|
| 231 |
+
|
| 232 |
+
def semantic_search_fn() -> str:
|
| 233 |
+
return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
|
| 234 |
+
|
| 235 |
+
result = "### Semantic Search Results (Before Training):\n" + f"{semantic_search_fn()}\n\n"
|
| 236 |
+
print("-" * 50 + "\nStarting Fine-tuning...")
|
| 237 |
+
train_with_dataset(model=self.model, dataset=final_dataset, output_dir=self.config.OUTPUT_DIR, task_name=self.config.TASK_NAME, search_fn=semantic_search_fn)
|
| 238 |
+
self._update_vibe_checker()
|
| 239 |
+
print("Fine-tuning Complete.\n" + "-" * 50)
|
| 240 |
+
|
| 241 |
+
result += "### Semantic Search Results (After Training):\n" + f"{semantic_search_fn()}"
|
| 242 |
+
return result
|
| 243 |
+
|
| 244 |
+
## Vibe Check Logic (Tab 2) ##
|
| 245 |
+
def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
|
| 246 |
+
if not self.vibe_checker:
|
| 247 |
+
gr.Error("Model/VibeChecker not loaded.")
|
| 248 |
+
return "N/A", "Model Error", gr.update(value=self._generate_vibe_html("gray"))
|
| 249 |
+
if not news_text or len(news_text.split()) < 3:
|
| 250 |
+
gr.Warning("Please enter a longer text for a meaningful check.")
|
| 251 |
+
return "N/A", "Please enter text", gr.update(value=self._generate_vibe_html("white"))
|
| 252 |
+
|
| 253 |
+
try:
|
| 254 |
+
vibe_result = self.vibe_checker.check(news_text)
|
| 255 |
+
status = vibe_result.status_html.split('>')[1].split('<')[0] # Extract text from HTML
|
| 256 |
+
return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl))
|
| 257 |
+
except Exception as e:
|
| 258 |
+
gr.Error(f"Vibe check failed. Error: {e}")
|
| 259 |
+
return "N/A", f"Processing Error: {e}", gr.update(value=self._generate_vibe_html("gray"))
|
| 260 |
+
|
| 261 |
+
def _generate_vibe_html(self, color: str) -> str:
|
| 262 |
+
return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
|
| 263 |
+
|
| 264 |
+
## Mood Reader Logic (Tab 3) ##
|
| 265 |
+
def fetch_and_display_mood_feed(self) -> str:
|
| 266 |
+
if not self.vibe_checker:
|
| 267 |
+
return "**FATAL ERROR:** The Mood Reader failed to initialize. Check console."
|
| 268 |
+
|
| 269 |
+
feed, status = read_hacker_news_rss(self.config)
|
| 270 |
+
if not feed or not feed.entries:
|
| 271 |
+
return f"**An error occurred while fetching the feed:** {status}"
|
| 272 |
+
|
| 273 |
+
scored_entries = []
|
| 274 |
+
for entry in feed.entries:
|
| 275 |
+
title = entry.get('title')
|
| 276 |
+
if not title: continue
|
| 277 |
+
|
| 278 |
+
vibe_result = self.vibe_checker.check(title)
|
| 279 |
+
scored_entries.append({
|
| 280 |
+
"title": title,
|
| 281 |
+
"link": entry.get('link', '#'),
|
| 282 |
+
"comments": entry.get('comments', '#'),
|
| 283 |
+
"published": format_published_time(entry.published_parsed),
|
| 284 |
+
"mood": vibe_result
|
| 285 |
+
})
|
| 286 |
+
|
| 287 |
+
scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
|
| 288 |
+
|
| 289 |
+
md = (f"## Hacker News Top Stories (Model: `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}) ⬇️\n"
|
| 290 |
+
f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
|
| 291 |
+
"| Vibe | Title | Comments | Published |\n|---|---|---|---|\n")
|
| 292 |
+
|
| 293 |
+
for item in scored_entries:
|
| 294 |
+
md += (f"| {item['mood'].status_html} "
|
| 295 |
+
f"| [{item['title']}]({item['link']}) "
|
| 296 |
+
f"| [Comments]({item['comments']}) "
|
| 297 |
+
f"| {item['published']} |\n")
|
| 298 |
+
return md
|
| 299 |
+
# 🤖 Embedding Gemma Modkit: Fine-Tuning and Mood Reader
|
| 300 |
+
|
| 301 |
+
## Gradio Interface Setup ##
|
| 302 |
+
def build_interface(self) -> gr.Blocks:
|
| 303 |
+
with gr.Blocks(title="Embedding Gemma Modkit") as demo:
|
| 304 |
+
gr.Markdown("# 🤖 Embedding Gemma Modkit: Fine-Tuning and Mood Reader")
|
| 305 |
+
gr.Markdown("See [README](./README.md) for more details.")
|
| 306 |
+
with gr.Tab("🚀 Fine-Tuning & Evaluation"):
|
| 307 |
+
self._build_training_interface()
|
| 308 |
+
with gr.Tab("💡 News Vibe Check"):
|
| 309 |
+
self._build_vibe_check_interface()
|
| 310 |
+
with gr.Tab("📰 Hacker News Mood Reader"):
|
| 311 |
+
self._build_mood_reader_interface()
|
| 312 |
+
return demo
|
| 313 |
+
|
| 314 |
+
def _build_training_interface(self):
|
| 315 |
+
with gr.Column():
|
| 316 |
+
gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
|
| 317 |
+
with gr.Row():
|
| 318 |
+
favorite_list = gr.CheckboxGroup(self.titles, type="index", label=f"Hacker News Top {len(self.titles)}", show_select_all=True)
|
| 319 |
+
output = gr.Textbox(lines=24, label="Training and Search Results", value="Click 'Run Fine-Tuning' to begin.")
|
| 320 |
+
with gr.Row():
|
| 321 |
+
clear_reload_btn = gr.Button("Clear & Reload Model/Data")
|
| 322 |
+
run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
|
| 323 |
+
gr.Markdown("--- \n ## Dataset & Model Management")
|
| 324 |
+
with gr.Row():
|
| 325 |
+
import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
|
| 326 |
+
download_dataset_btn = gr.Button("💾 Export Last HN Dataset")
|
| 327 |
+
download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
|
| 328 |
+
download_status = gr.Markdown("Ready.")
|
| 329 |
+
with gr.Row():
|
| 330 |
+
dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
|
| 331 |
+
model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
|
| 332 |
+
|
| 333 |
+
run_training_btn.click(fn=self.training, inputs=favorite_list, outputs=output)
|
| 334 |
+
clear_reload_btn.click(fn=self.refresh_data_and_model, inputs=None, outputs=[favorite_list, output], queue=False)
|
| 335 |
+
import_file.change(fn=self.import_additional_dataset, inputs=[import_file], outputs=download_status)
|
| 336 |
+
download_dataset_btn.click(lambda: [gr.update(value=None, visible=False), "Generating..."], None, [dataset_output, download_status], queue=False).then(self.export_dataset, None, dataset_output).then(lambda p: [gr.update(visible=p is not None, value=p), "CSV ready." if p else "Export failed."], [dataset_output], [dataset_output, download_status])
|
| 337 |
+
download_model_btn.click(lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False).then(self.download_model, None, model_output).then(lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status])
|
| 338 |
+
|
| 339 |
+
def _build_vibe_check_interface(self):
|
| 340 |
+
with gr.Column():
|
| 341 |
+
gr.Markdown(f"## News Vibe Check Mood Lamp\nEnter text to see its similarity to **`{self.config.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
|
| 342 |
+
news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
|
| 343 |
+
vibe_check_btn = gr.Button("Check Vibe", variant="primary")
|
| 344 |
+
with gr.Row():
|
| 345 |
+
vibe_color_block = gr.HTML(value=self._generate_vibe_html("white"), label="Mood Lamp")
|
| 346 |
+
with gr.Column():
|
| 347 |
+
vibe_score = gr.Textbox(label="Cosine Similarity Score", value="N/A", interactive=False)
|
| 348 |
+
vibe_status = gr.Textbox(label="Vibe Status", value="Enter text and click 'Check Vibe'", interactive=False, lines=2)
|
| 349 |
+
vibe_check_btn.click(fn=self.get_vibe_check, inputs=[news_input], outputs=[vibe_score, vibe_status, vibe_color_block])
|
| 350 |
+
|
| 351 |
+
def _build_mood_reader_interface(self):
|
| 352 |
+
with gr.Column():
|
| 353 |
+
gr.Markdown(f"## Live Hacker News Feed Vibe\nThis feed uses the current model (base or fine-tuned) to score the vibe of live HN stories against **`{self.config.QUERY_ANCHOR}`**.")
|
| 354 |
+
feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
|
| 355 |
+
refresh_button = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
|
| 356 |
+
refresh_button.click(fn=self.fetch_and_display_mood_feed, inputs=None, outputs=feed_output)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
if __name__ == "__main__":
|
| 360 |
+
app = HackerNewsFineTuner(AppConfig)
|
| 361 |
+
demo = app.build_interface()
|
| 362 |
+
print("Starting Gradio App...")
|
| 363 |
+
demo.launch()
|
| 364 |
+
|
cli_mood_reader.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import shutil
|
| 4 |
+
import click
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
# --- Core Logic Imports ---
|
| 9 |
+
# These modules contain the application's functionality.
|
| 10 |
+
from config import AppConfig
|
| 11 |
+
from hn_mood_reader import HnMoodReader, FeedEntry
|
| 12 |
+
from vibe_logic import VIBE_THRESHOLDS
|
| 13 |
+
|
| 14 |
+
# --- Helper Functions ---
|
| 15 |
+
|
| 16 |
+
def get_status_text_and_color(score: float) -> tuple[str, str]:
    """
    Determines the plain text status and a corresponding color for a given score.

    Args:
        score: Raw similarity score; clamped to [0.0, 1.0] before threshold lookup.

    Returns:
        (status, color) where status is the bare "VIBE:<LEVEL>" tag and color is
        a click/terminal color name ("white" for unknown statuses).
    """
    clamped_score = max(0.0, min(1.0, score))

    # Define colors for different vibe levels
    color_map = {
        "VIBE:HIGH": "green",
        "VIBE:GOOD": "cyan",
        "VIBE:FLAT": "yellow",
        "VIBE:LOW": "red",
    }

    def _status_tag(label: str) -> str:
        # BUG FIX: the previous code used label.split(" ")[-1], which returns an
        # EMPTY string for "👎 VIBE:LOW " (the trailing space makes the last
        # split element ""), so low scores rendered with no status and a white
        # color. split() with no argument collapses whitespace and drops empties.
        return label.split()[-1]

    for threshold in VIBE_THRESHOLDS:
        if clamped_score >= threshold.score:
            status = _status_tag(threshold.status)
            return status, color_map.get(status, "white")

    # Fallback for the lowest score (unreachable while a 0.0 threshold exists,
    # kept for safety if VIBE_THRESHOLDS changes).
    status = _status_tag(VIBE_THRESHOLDS[-1].status)
    return status, color_map.get(status, "white")
|
| 38 |
+
|
| 39 |
+
def initialize_reader(model_name: str) -> HnMoodReader:
    """
    Build the HnMoodReader for `model_name`.

    Terminates the process with exit code 1 if the model cannot be loaded,
    since the CLI is useless without a scoring model.
    """
    click.echo(f"Initializing mood reader with model: '{model_name}'...", err=True)
    try:
        mood_reader = HnMoodReader(model_name=model_name)
        click.secho("✅ Model loaded successfully.", fg="green", err=True)
        return mood_reader
    except Exception as e:
        click.secho(f"❌ FATAL: Could not initialize model '{model_name}'.", fg="red", err=True)
        click.secho(f" Error: {e}", fg="red", err=True)
        sys.exit(1)  # Exit with a non-zero code to indicate failure
|
| 53 |
+
|
| 54 |
+
def display_feed(scored_entries: List[FeedEntry], top: int, offset: int, model_name: str):
    """Clears the screen and displays the current slice of the feed.

    Args:
        scored_entries: All stories, already sorted best-vibe-first by the caller.
        top: Number of rows to show per screen ("page" size).
        offset: Index of the first story to display (current scroll position).
        model_name: Shown in the header so the user knows which model scored the feed.
    """
    click.clear()

    # Get terminal width, but default to 80 if it's too narrow
    # to avoid breaking the layout.
    try:
        terminal_width = shutil.get_terminal_size()[0]
    except OSError: # Handle cases where terminal size can't be determined (e.g., in a pipe)
        terminal_width = 80

    click.echo(f"📰 Hacker News Mood Reader")
    click.echo(f" Model: {model_name}")
    click.echo(f" Showing {offset + 1}-{min(offset + top, len(scored_entries))} of {len(scored_entries)} stories")
    click.secho("=" * terminal_width, fg="blue")

    header = f"{'VIBE':<5} | {'SCORE':<7} | {'PUBLISHED':<16} | {'TITLE'}"
    click.secho(header, bold=True)
    click.secho("-" * terminal_width, fg="blue")

    # Calculate the fixed width of the columns before the title
    # Vibe: 5
    # Score: | + ' ' + '0.0000' + ' ' = 9
    # Published: | + ' ' + 'YYYY-MM-DD HH:MM' + ' ' + | + ' ' = 21
    # Total fixed width = 5 + 9 + 21 = 35
    fixed_width = 35
    max_title_width = terminal_width - fixed_width
    # --- MODIFICATION END ---

    if not scored_entries:
        click.echo("No entries found in the feed.")
    else:
        # Display the current "page" of entries based on the offset
        for entry in scored_entries[offset:offset + top]:
            status, color = get_status_text_and_color(entry.mood.raw_score)

            # --- MODIFICATION: VIBE width changed from 12 to 5 ---
            # Also ensure the status text itself is truncated if it's longer than 5
            # NOTE(review): status[5:] keeps everything AFTER the first five
            # characters — i.e. it strips a leading "VIBE:" prefix
            # ("VIBE:HIGH" -> "HIGH") rather than truncating TO five characters
            # as the comment above suggests. Confirm which behavior is intended.
            truncated_status = status[5:]
            vibe_part = click.style(f"{truncated_status:<5}", fg=color)

            score_part = f"| {entry.mood.raw_score:>.4f} "
            published_part = f"| {entry.published_time_str:<16} | "

            # --- Title Truncation Logic ---
            full_title = entry.title

            if len(full_title) > max_title_width:
                # Truncate and add ellipsis, reserving 3 chars for '...'
                title_part = full_title[:max_title_width - 3] + "..."
            else:
                title_part = full_title
            # --- End Title Truncation ---

            # Combine parts and print
            full_line = vibe_part + score_part + published_part + title_part
            click.echo(full_line)

    click.secho("-" * terminal_width, fg="blue")
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# --- Main Application Logic (CLI Command) ---
|
| 116 |
+
|
| 117 |
+
@click.command()
@click.option(
    "-m", "--model",
    help="Name of the Sentence Transformer model from Hugging Face. Overrides MOOD_MODEL env var.",
    default=None,
    show_default=False
)
@click.option(
    "-n", "--top",
    help="Number of stories to display on screen at once.",
    default=15,
    type=int,
    show_default=True
)
def main(model: str | None, top: int) -> None:
    """
    Fetch and display Hacker News stories scored by a sentence-embedding model.
    Runs continuously. Use arrow keys to scroll, [SPACE] to refresh, [q] to quit.
    """
    # --- State Management ---
    # Model resolution precedence: CLI flag > MOOD_MODEL env var > configured default.
    model_name = model or os.environ.get("MOOD_MODEL") or AppConfig.DEFAULT_MOOD_READER_MODEL
    reader = initialize_reader(model_name)  # exits the process on failure
    scored_entries: List[FeedEntry] = []
    scroll_offset = 0  # index of first visible story

    # --- Initial Fetch ---
    # A failed initial fetch is non-fatal: the loop still starts with an empty
    # list and the user can retry with [SPACE].
    click.echo("Fetching initial feed...", err=True)
    try:
        scored_entries = reader.fetch_and_score_feed()
    except Exception as e:
        click.secho(f"❌ ERROR: Initial fetch failed: {e}", fg="red", err=True)

    # --- Main Loop ---
    while True:
        display_feed(scored_entries, top, scroll_offset, reader.model_name)

        click.secho("Use [↑|↓] to scroll, [SPACE] to refresh, or [q] to quit.", bold=True, err=True)
        # Blocks until a single keypress; arrow keys arrive as escape sequences.
        key = click.getchar()

        if key == ' ':
            click.echo("Refreshing feed...", err=True)
            try:
                scored_entries = reader.fetch_and_score_feed()
                scroll_offset = 0 # Reset scroll on refresh
            except Exception as e:
                # Keep the previously fetched entries on a failed refresh.
                click.secho(f"❌ ERROR: Refresh failed: {e}", fg="red", err=True)
                continue

        elif key in ('q', 'Q'):
            click.echo("Exiting.")
            break

        # Arrow key handling for scrolling (might produce escape sequences)
        elif key == '\x1b[A': # Up Arrow
            scroll_offset = max(0, scroll_offset - 1)
        elif key == '\x1b[B': # Down Arrow
            # Prevent scrolling past the last page
            scroll_offset = min(scroll_offset + 1, max(0, len(scored_entries) - top))
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
    # Entry point: click parses the CLI options declared above and invokes main().
    main()
|
| 178 |
+
|
| 179 |
+
|
config.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Final
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
# --- Base Directory Definition ---
# Use Path for modern, OS-agnostic path handling
ARTIFACTS_DIR: Final[Path] = Path("artifacts")

class AppConfig:
    """
    Central configuration class for the Hacker News Fine-Tuner application.

    Used purely as a constants namespace — it is never instantiated; modules
    receive it (or import it) and read class attributes directly.
    """

    # --- Directory/Environment Configuration ---
    # Re-exported so code that only receives AppConfig can reach the base dir.
    ARTIFACTS_DIR: Final[Path] = ARTIFACTS_DIR

    # Environment variable for Hugging Face token (used by model_trainer)
    HF_TOKEN: Final[str | None] = os.getenv('HF_TOKEN')


    # --- Caching/Data Fetching Configuration ---
    HN_RSS_URL: Final[str] = "https://news.ycombinator.com/rss"

    # Filename for the pickled cache data (using Path.joinpath)
    CACHE_FILE: Final[Path] = ARTIFACTS_DIR.joinpath("hacker_news_cache.pkl")

    # Cache duration set to 30 minutes (1800 seconds)
    CACHE_DURATION_SECONDS: Final[int] = 60 * 30


    # --- Model/Training Configuration ---

    # Name of the pre-trained embedding model
    MODEL_NAME: Final[str] = 'google/embeddinggemma-300M'

    # Task name for prompting the embedding model (e.g., for instruction tuning)
    TASK_NAME: Final[str] = "Classification"

    # Output directory for the fine-tuned model
    OUTPUT_DIR: Final[Path] = ARTIFACTS_DIR.joinpath("embedding-gemma-finetuned-hn")


    # --- Gradio/App-Specific Configuration ---

    # Anchor text used for contrastive learning dataset generation
    QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"

    # Number of titles shown for user selection in the Gradio interface
    TOP_TITLES_COUNT: Final[int] = 10

    # Default export path for the dataset CSV
    DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")

    # Default model for the standalone Mood Reader tab
    DEFAULT_MOOD_READER_MODEL: Final[str] = "bebechien/embedding-gemma-finetuned-hn"
|
| 56 |
+
|
data_fetcher.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import feedparser
|
| 2 |
+
import pickle
|
| 3 |
+
import os
|
| 4 |
+
import time
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import Tuple, Any, Optional
|
| 7 |
+
|
| 8 |
+
# Assuming AppConfig is passed in via dependency injection in the refactored main app.
|
| 9 |
+
|
| 10 |
+
def format_published_time(published_parsed: Optional[time.struct_time]) -> str:
    """Render a feedparser time struct as 'YYYY-MM-DD HH:MM'.

    Returns 'N/A' when the struct is missing/empty or cannot be converted.
    """
    if not published_parsed:
        return 'N/A'
    try:
        epoch_seconds = time.mktime(published_parsed)
        return datetime.fromtimestamp(epoch_seconds).strftime('%Y-%m-%d %H:%M')
    except Exception:
        # Out-of-range or malformed struct values — fall back to the placeholder.
        return 'N/A'
|
| 19 |
+
|
| 20 |
+
def load_feed_from_cache(config: Any) -> Tuple[Optional[Any], str]:
    """Return (feed, status) from the pickle cache, or (None, reason) on a miss.

    A miss occurs when the cache file is absent, older than the configured TTL,
    or cannot be unpickled — in the last case the corrupt file is deleted so the
    next run re-fetches cleanly.
    """
    cache_path = config.CACHE_FILE
    if not os.path.exists(cache_path):
        return None, "Cache file not found."

    try:
        age = time.time() - os.path.getmtime(cache_path)
        if age > config.CACHE_DURATION_SECONDS:
            # Stale: report how old it is versus the allowed limit.
            return None, f"Cache expired ({age:.0f}s old, limit is {config.CACHE_DURATION_SECONDS}s)."

        with open(cache_path, 'rb') as handle:
            cached_feed = pickle.load(handle)
    except Exception as e:
        # Treat unreadable/corrupt caches as a miss and clean up best-effort.
        print(f"Warning: Failed to load cache file. Deleting corrupted cache. Reason: {e}")
        try:
            os.remove(cache_path)
        except OSError:
            pass  # Ignore if removal fails
        return None, "Cache file corrupted or invalid. Will re-fetch."

    return cached_feed, f"Loaded successfully from cache (Age: {age:.0f}s)."
|
| 45 |
+
|
| 46 |
+
def save_feed_to_cache(config: Any, feed: Any) -> None:
    """Persist the parsed feed object to the pickle cache file (best-effort)."""
    try:
        with open(config.CACHE_FILE, 'wb') as handle:
            pickle.dump(feed, handle)
        print(f"Successfully saved new feed data to cache: {config.CACHE_FILE}")
    except Exception as e:
        # Caching is an optimization — log and carry on rather than fail the fetch.
        print(f"Error saving to cache: {e}")
|
| 54 |
+
|
| 55 |
+
def read_hacker_news_rss(config: Any) -> Tuple[Optional[Any], str]:
    """
    Reads and parses the Hacker News RSS feed, using a cache if available.

    Args:
        config: Object exposing HN_RSS_URL, CACHE_FILE, CACHE_DURATION_SECONDS.

    Returns:
        (feed, status_message) — feed is the feedparser result, or None on any
        failure or when the feed contains no entries.
    """
    url = config.HN_RSS_URL
    print(f"Attempting to fetch and parse RSS feed from: {url}")
    print("-" * 50)

    # 1. Attempt to load from cache
    feed, cache_status = load_feed_from_cache(config)
    print(f"Cache Status: {cache_status}")

    if feed is not None:
        return feed, cache_status

    # 2. Cache miss or stale: fetch from the network.
    print("Starting network fetch...")
    try:
        # Use feedparser to fetch and parse the feed
        feed = feedparser.parse(url)

        # BUG FIX: feedparser omits the 'status' attribute entirely when the
        # fetch fails before an HTTP response arrives (DNS/connection errors),
        # so the previous `feed.status` raised AttributeError and was masked by
        # the broad except below with a misleading "unexpected error" message.
        http_status = getattr(feed, 'status', None)
        if http_status is not None and http_status >= 400:
            status_msg = f"Error fetching the feed. HTTP Status: {http_status}"
            print(status_msg)
            return None, status_msg

        if feed.bozo:
            # Bozo is set if any error occurred, even non-critical ones.
            print(f"Warning: Failed to fully parse the feed. Reason: {feed.get('bozo_exception')}")

        # 3. If fetch successful, save new data to cache
        if feed.entries:
            save_feed_to_cache(config, feed)
            status_msg = f"Successfully fetched and cached {len(feed.entries)} entries."
        else:
            status_msg = "Fetch successful, but no entries found in the feed."
            print(status_msg)
            feed = None  # Ensure feed is None if no entries

    except Exception as e:
        status_msg = f"An unexpected error occurred during network processing: {e}"
        print(status_msg)
        return None, status_msg

    return feed, status_msg
|
| 102 |
+
|
| 103 |
+
# Example usage (not part of the refactored module's purpose but good for testing)
|
| 104 |
+
# Example usage (not part of the refactored module's purpose but good for testing)
if __name__ == '__main__':
    from config import AppConfig

    parsed_feed, status = read_hacker_news_rss(AppConfig)
    if parsed_feed and parsed_feed.entries:
        print(f"\nFetched {len(parsed_feed.entries)} entries. Top 3 titles:")
        for item in parsed_feed.entries[:3]:
            print(f"- {item.title}")
    else:
        print(f"Could not fetch the feed. Status: {status}")
|
flask_app.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
from flask import Flask, render_template
|
| 8 |
+
|
| 9 |
+
# Your existing config and core logic
|
| 10 |
+
from config import AppConfig
|
| 11 |
+
from hn_mood_reader import HnMoodReader, FeedEntry
|
| 12 |
+
|
| 13 |
+
# --- Flask App Initialization ---
|
| 14 |
+
app = Flask(__name__)
|
| 15 |
+
|
| 16 |
+
# --- Global Cache for the Model ---
|
| 17 |
+
global_reader: Optional[HnMoodReader] = None
|
| 18 |
+
|
| 19 |
+
def initialize_reader() -> HnMoodReader:
    """
    Initializes the HnMoodReader instance. This function is called once
    when the application starts.

    The model name comes from the MOOD_MODEL env var, falling back to the
    configured default. Exits the process if the model cannot load, since the
    app cannot serve anything without it.
    """
    # BUG FIX: the module used `sys.stderr` / `sys.exit` below but never
    # imported `sys`, so the failure path raised NameError instead of exiting
    # cleanly. Imported locally so this function is self-contained.
    import sys

    print("Attempting to initialize the mood reader model...")
    model_name = os.environ.get("MOOD_MODEL", AppConfig.DEFAULT_MOOD_READER_MODEL)
    try:
        reader = HnMoodReader(model_name=model_name)
        print("Model loaded successfully.")
        return reader
    except Exception as e:
        # If the model fails to load, print a fatal error and exit the app.
        print(f"FATAL: Could not initialize model '{model_name}'. Error: {e}", file=sys.stderr)
        sys.exit(1)  # Exit with a non-zero code to indicate failure
|
| 34 |
+
|
| 35 |
+
# --- Initialize the reader as soon as the app starts ---
|
| 36 |
+
global_reader = initialize_reader()
|
| 37 |
+
|
| 38 |
+
# --- Flask Route ---
|
| 39 |
+
@app.route('/')
def index():
    """Main page route: render the scored feed, or an error page on failure."""
    try:
        scored = global_reader.fetch_and_score_feed()
        page = render_template(
            'index.html',
            entries=scored,
            model_name=global_reader.model_name,
            last_updated=datetime.now().strftime('%H:%M:%S'),
        )
        return page
    except Exception as e:
        # Render a simple error page instead of leaking a stack trace.
        return render_template('error.html', error=str(e)), 500
|
| 54 |
+
|
| 55 |
+
if __name__ == '__main__':
    # Using debug=False is recommended for a stable display
    # use_reloader=False prevents the app from initializing the model twice in debug mode
    # NOTE: 0.0.0.0 binds all interfaces — intended for container/Space hosting.
    app.run(host='0.0.0.0', port=5000, debug=False, use_reloader=False)
|
hn_mood_reader.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# hn_mood_reader.py
|
| 2 |
+
|
| 3 |
+
import feedparser
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import List
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
# Assuming these are in separate files as in the original structure
|
| 10 |
+
from config import AppConfig
|
| 11 |
+
from data_fetcher import format_published_time
|
| 12 |
+
from vibe_logic import VibeChecker, VibeResult
|
| 13 |
+
|
| 14 |
+
# --- Data Structures ---
|
| 15 |
+
# --- Data Structures ---
@dataclass(frozen=True)
class FeedEntry:
    """Stores necessary data for a single HN story, including its calculated mood."""
    # Story headline as it appears in the RSS feed.
    title: str
    # URL of the linked article.
    link: str
    # URL of the HN comments page ('#' when the feed omits it).
    comments_link: str
    # Publication time pre-formatted as 'YYYY-MM-DD HH:MM' (or 'N/A').
    published_time_str: str
    # Vibe score/status computed against the configured query anchor.
    mood: VibeResult
|
| 23 |
+
|
| 24 |
+
# --- Core Logic Class ---
|
| 25 |
+
class HnMoodReader:
    """Handles model initialization and mood scoring for Hacker News titles."""

    def __init__(self, model_name: str):
        """Load the embedding model and prepare a VibeChecker for scoring.

        Raises:
            ImportError: if sentence-transformers is not installed.
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise ImportError("Please install 'sentence-transformers'") from e

        print(f"Initializing SentenceTransformer with model: {model_name}...")
        # truncate_dim=128 shortens the embeddings to speed up scoring.
        self.model = SentenceTransformer(model_name, truncate_dim=128)
        print("Model initialized successfully.")

        self.vibe_checker = VibeChecker(
            model=self.model,
            query_anchor=AppConfig.QUERY_ANCHOR,
            task_name=AppConfig.TASK_NAME
        )
        self.model_name = model_name

    def _get_mood_result(self, title: str) -> VibeResult:
        """Calculates the mood for a title using the VibeChecker."""
        return self.vibe_checker.check(title)

    def fetch_and_score_feed(self) -> List[FeedEntry]:
        """Fetches, scores, and sorts entries from the HN RSS feed (best vibe first).

        Raises:
            IOError: only when the feed is malformed AND yields no entries.
                feedparser sets `bozo` for many recoverable problems, so a bozo
                flag with usable entries is treated as a warning — consistent
                with data_fetcher.read_hacker_news_rss. (Previously any bozo
                flag raised, discarding perfectly usable feeds.)
        """
        feed = feedparser.parse(AppConfig.HN_RSS_URL)
        if feed.bozo and not feed.entries:
            # Hard failure: nothing could be parsed at all.
            raise IOError(f"Error parsing feed from {AppConfig.HN_RSS_URL}.")

        scored_entries: List[FeedEntry] = []
        for entry in feed.entries:
            title, link = entry.get('title'), entry.get('link')
            if not title or not link:
                continue

            scored_entries.append(
                FeedEntry(
                    title=title,
                    link=link,
                    comments_link=entry.get('comments', '#'),
                    # BUG FIX: use .get() — direct attribute access raised when
                    # an entry lacked <pubDate>; format_published_time maps the
                    # resulting None to 'N/A'.
                    published_time_str=format_published_time(entry.get('published_parsed')),
                    mood=self._get_mood_result(title)
                )
            )

        scored_entries.sort(key=lambda x: x.mood.raw_score, reverse=True)
        return scored_entries
|
model_trainer.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import login
|
| 2 |
+
from sentence_transformers import SentenceTransformer, util
|
| 3 |
+
from datasets import Dataset
|
| 4 |
+
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
| 5 |
+
from sentence_transformers.losses import MultipleNegativesRankingLoss
|
| 6 |
+
from transformers import TrainerCallback, TrainingArguments
|
| 7 |
+
from typing import List, Callable, Optional
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# --- Model/Utility Functions ---
|
| 11 |
+
|
| 12 |
+
def authenticate_hf(token: Optional[str]) -> None:
    """Logs into the Hugging Face Hub when a token is provided; no-op otherwise."""
    if not token:
        print("Skipping Hugging Face login: HF_TOKEN not set.")
        return
    print("Logging into Hugging Face Hub...")
    login(token=token)
|
| 19 |
+
|
| 20 |
+
def load_embedding_model(model_name: str) -> SentenceTransformer:
    """Initializes the Sentence Transformer model, logging and re-raising failures."""
    print(f"Loading Sentence Transformer model: {model_name}")
    try:
        embedder = SentenceTransformer(model_name)
        print("Model loaded successfully.")
        return embedder
    except Exception as e:
        # Surface the failure in the logs, then let the caller decide.
        print(f"Error loading Sentence Transformer model {model_name}: {e}")
        raise
|
| 30 |
+
|
| 31 |
+
def get_top_hits(
    model: SentenceTransformer,
    target_titles: List[str],
    task_name: str,
    query: str = "MY_FAVORITE_NEWS",
    top_k: int = 5
) -> str:
    """Performs semantic search of `query` over `target_titles`.

    Returns one "[title] score" line per hit, best match first, or a
    placeholder message when there is nothing to search.
    """
    if not target_titles:
        return "No target titles available for search."

    # Embed the query and the corpus with the task-specific prompt.
    query_vec = model.encode(query, prompt_name=task_name)
    corpus_vecs = model.encode(target_titles, prompt_name=task_name)

    # semantic_search returns one hit-list per query; we only have one query.
    hits = util.semantic_search(query_vec, corpus_vecs, top_k=top_k)[0]

    return "\n".join(
        f"[{target_titles[hit['corpus_id']]}] {hit['score']:.4f}"
        for hit in hits
    )
|
| 58 |
+
|
| 59 |
+
# --- Training Class and Function ---
|
| 60 |
+
|
| 61 |
+
class EvaluationCallback(TrainerCallback):
    """
    A callback that runs the semantic search evaluation at the end of each log step.
    The search function is passed in during initialization.
    """
    def __init__(self, search_fn: Callable[[], str]):
        # Zero-argument function returning a formatted evaluation string —
        # typically a closure over get_top_hits with the live model and titles.
        self.search_fn = search_fn

    def on_log(self, args: TrainingArguments, state, control, **kwargs):
        # Invoked by the Trainer each time it logs (every `logging_steps` steps).
        print(f"Step {state.global_step} finished. Running evaluation:")
        print(f"\n{self.search_fn()}\n")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def train_with_dataset(
    model: SentenceTransformer,
    dataset: List[List[str]],
    output_dir: Path,
    task_name: str,
    search_fn: Callable[[], str]
) -> None:
    """
    Fine-tunes the provided Sentence Transformer MODEL on the dataset.

    The dataset should be a list of lists: [[anchor, positive, negative], ...].

    Args:
        model: Model whose weights are updated in place during training.
        dataset: Contrastive triplets of raw strings.
        output_dir: Directory where the final fine-tuned model is saved.
        task_name: Key used to look up the model's training prompt.
        search_fn: Zero-arg evaluation closure run after every logging step.
    """
    # Convert to Hugging Face Dataset format
    data_as_dicts = [
        {"anchor": row[0], "positive": row[1], "negative": row[2]}
        for row in dataset
    ]

    train_dataset = Dataset.from_list(data_as_dicts)

    # Use MultipleNegativesRankingLoss, suitable for contrastive learning
    loss = MultipleNegativesRankingLoss(model)

    # Note: SentenceTransformer models typically have a 'prompts' attribute
    # which we need to access for the training arguments.
    prompts = getattr(model, 'prompts', {}).get(task_name)
    if not prompts:
        print(f"Warning: Could not find prompts for task '{task_name}' in model. Training may be less effective.")
        # Fallback to an empty list or appropriate default if required by the model's structure
        prompts = []

    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        prompts=prompts,
        num_train_epochs=4,
        per_device_train_batch_size=1,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        # One log event per full pass over the (small) dataset, so the
        # evaluation callback fires roughly at epoch boundaries.
        logging_steps=train_dataset.num_rows,
        report_to="none",
        save_strategy="no" # No saving during training, only at the end
    )

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        callbacks=[EvaluationCallback(search_fn)]
    )

    trainer.train()

    print("Training finished. Model weights are updated in memory.")

    # Save the final fine-tuned model
    trainer.save_model()

    print(f"Model saved locally to: {output_dir}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate
|
| 2 |
+
beautifulsoup4
|
| 3 |
+
datasets
|
| 4 |
+
feedparser
|
| 5 |
+
flask
|
| 6 |
+
gradio
|
| 7 |
+
html_to_markdown
|
| 8 |
+
sentence-transformers
|
| 9 |
+
git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
|
templates/error.html
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Error</title>
    <!-- Dark theme with red text, matching the terminal look of the main page. -->
    <style>body { background-color: #121212; color: #ff5555; font-family: sans-serif; padding: 2rem; }</style>
</head>
<body>
    <h1>An Error Occurred</h1>
    <p>Could not load the feed. See server logs for details.</p>
    <!-- Error string supplied by the Flask index route; auto-escaped by Jinja2. -->
    <pre>{{ error }}</pre>
</body>
</html>
|
templates/index.html
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <!-- Auto-reload the page every 300s so the feed re-scores without user action. -->
    <meta http-equiv="refresh" content="300">
    <title>Hacker News Vibe Reader</title>
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Press+Start+2P&display=swap" rel="stylesheet">

    <style>
        body {
            /* Use the imported pixel font */
            font-family: 'Press Start 2P', cursive;
            background-color: #1a1a1a; /* Dark background */
            color: #00ff00; /* Classic green terminal text */
            margin: 0;
            padding: 1rem;
            /* Prevents font anti-aliasing to keep it crisp */
            -webkit-font-smoothing: none;
            -moz-osx-font-smoothing: grayscale;
            /* Ensures emojis are rendered as pixels */
            image-rendering: pixelated;
        }

        .container {
            max-width: 900px;
            margin: 1rem auto;
            border: 2px solid #00ff00;
            padding: 1.5rem;
            /* Hard, blocky shadow for a retro UI feel */
            box-shadow: 5px 5px 0px #005f00;
        }

        h1 {
            font-size: 1.2rem;
            color: #ffffff;
            text-shadow: 2px 2px #00ff00;
            margin-top: 0;
        }

        .meta-info {
            font-size: 0.7rem;
            color: #8cff8c;
            margin-bottom: 2rem;
            border-bottom: 2px solid #005f00;
            padding-bottom: 1rem;
        }

        ul {
            list-style-type: none;
            padding: 0;
        }

        li {
            display: flex;
            align-items: baseline;
            margin-bottom: 1.5rem;
        }

        /* Emoji/status badge column to the left of each story. */
        .vibe {
            flex-shrink: 0;
            margin-right: 1rem;
            font-size: 1.5rem;
        }

        .title a {
            color: #ffffff; /* Brighter white for main links */
            text-decoration: none;
            font-size: 0.8rem;
            line-height: 1.5;
        }

        .title a:hover {
            background-color: #00ff00;
            color: #1a1a1a;
        }

        .details {
            font-size: 0.7rem;
            color: #8cff8c; /* Dimmer green for details */
            margin-top: 0.5rem;
        }

        .details a {
            color: #00ff00;
            text-decoration: underline;
        }

        .details a:hover {
            color: #ffffff;
            background-color: transparent;
        }

        /* Make sure code tags also use the pixel font */
        code {
            font-family: 'Press Start 2P', cursive;
            background-color: #005f00;
            padding: 2px 4px;
        }

    </style>
</head>
<body>
    <div class="container">
        <h1>[ Hacker News Vibe Reader ]</h1>
        <div class="meta-info">
            MODEL: <code>{{ model_name }}</code> <br>
            UPDATED: {{ last_updated }}
        </div>
        <ul>
            <!-- Entries arrive pre-sorted (best vibe first) from the Flask route. -->
            {% for item in entries %}
            <li>
                <!-- status_html is produced server-side; rendered unescaped via |safe. -->
                <div class="vibe">{{ item.mood.status_html | safe }}</div>
                <div>
                    <div class="title"><a href="{{ item.link }}" target="_blank">{{ item.title }}</a></div>
                    <div class="details">
                        {{ item.published_time_str }} | <a href="{{ item.comments_link }}" target="_blank">COMMENTS</a>
                    </div>
                </div>
            </li>
            {% endfor %}
        </ul>
    </div>
</body>
</html>
|
vibe_logic.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from math import floor
|
| 3 |
+
from typing import List
|
| 4 |
+
from sentence_transformers import SentenceTransformer, util
|
| 5 |
+
|
| 6 |
+
# --- Data Structures ---
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
class VibeThreshold:
    """A score cutoff paired with the status label shown at or above it."""
    score: float
    status: str


@dataclass(frozen=True)
class VibeResult:
    """The rendered outcome for one similarity score."""
    raw_score: float   # original (unclamped) similarity score
    status_html: str   # pre-formatted HTML snippet for display
    color_hsl: str     # raw HSL color string, e.g. "hsl(96, 80%, 50%)"


# Thresholds ordered from highest score to lowest; the first match wins.
VIBE_THRESHOLDS: List[VibeThreshold] = [
    VibeThreshold(score=0.8, status="✨ VIBE:HIGH"),
    VibeThreshold(score=0.5, status="👍 VIBE:GOOD"),
    VibeThreshold(score=0.2, status="😐 VIBE:FLAT"),
    VibeThreshold(score=0.0, status="👎 VIBE:LOW "),  # base case for scores < 0.2
]


# --- Utility Functions ---

def map_score_to_vibe(score: float) -> VibeResult:
    """Translate a cosine-similarity score into a colored status badge.

    The score is clamped to [0, 1] for display purposes only; the raw
    value is preserved on the returned VibeResult.
    """
    bounded = min(1.0, max(0.0, score))

    # Hue sweeps linearly from red (0) at score 0 to green (120) at score 1.
    hue = floor(bounded * 120)
    color_hsl = f"hsl({hue}, 80%, 50%)"

    # First threshold at or below the bounded score wins; the list ends at
    # 0.0, so the fallback default can only matter defensively.
    status_text: str = next(
        (t.status for t in VIBE_THRESHOLDS if bounded >= t.score),
        VIBE_THRESHOLDS[-1].status,
    )

    # Pre-render the HTML so templates can drop it in with `| safe`.
    status_html = f"<span style='color: {color_hsl}; font-weight: bold;'>{status_text}</span>"

    return VibeResult(raw_score=score, status_html=status_html, color_hsl=color_hsl)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# --- Core Logic Class ---
|
| 56 |
+
|
| 57 |
+
class VibeChecker:
    """Scores text similarity against a fixed anchor query.

    Wraps a SentenceTransformer model; the anchor's embedding is computed
    once at construction so each call to ``check`` only has to encode the
    candidate text.
    """
    def __init__(self, model: SentenceTransformer, query_anchor: str, task_name: str):
        self.model = model
        self.query_anchor = query_anchor
        self.task_name = task_name

        # Embed the anchor once up front; every check() reuses it.
        self.query_embedding = self._encode(self.query_anchor)

    def _encode(self, text: str):
        """Embed *text* with the configured prompt, unit-normalized."""
        return self.model.encode(
            text,
            prompt_name=self.task_name,
            normalize_embeddings=True,
        )

    def check(self, text: str) -> VibeResult:
        """Score *text* against the pre-configured anchor and map it to a vibe."""
        candidate_embedding = self._encode(text)

        # With unit-normalized vectors the dot product equals cosine similarity.
        similarity: float = util.dot_score(self.query_embedding, candidate_embedding).item()

        return map_score_to_vibe(similarity)
|