|
|
from dataclasses import dataclass
from math import floor, isnan
from typing import List

from sentence_transformers import SentenceTransformer, util
|
|
|
|
|
|
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class VibeThreshold:
    """Immutable (minimum score, status label) pair describing one vibe band."""

    # Lower bound of the band; map_score_to_vibe tests a clamped similarity
    # against it with ``>=``.
    score: float
    # Label placed inside the result's status HTML when this band is selected.
    status: str
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class VibeResult:
    """Immutable outcome of scoring one text: raw score, status markup and colour."""

    # Score as handed to map_score_to_vibe, i.e. before it clamps the value.
    raw_score: float
    # HTML ``<span>`` fragment wrapping the status label, coloured with color_hsl.
    status_html: str
    # CSS colour string of the form ``hsl(H, 80%, 50%)`` as built by map_score_to_vibe.
    color_hsl: str
|
|
|
|
|
|
|
|
# Vibe bands used by map_score_to_vibe to label a (clamped) similarity score:
# a score is given the highest band whose ``score`` it meets or exceeds.
# Listed from highest to lowest ``score``; keep that ordering when adding bands.
# The 0.0 entry is the lowest band, so every clamped score reaches at least it.
# NOTE(review): the status prefixes ("β¨", "π") look like emoji that were
# mis-decoded at some point (mojibake); confirm the file's encoding and the
# intended characters before relying on these labels.
VIBE_THRESHOLDS: List[VibeThreshold] = [
    VibeThreshold(score=0.8, status="β¨ VIBE:HIGH"),
    VibeThreshold(score=0.5, status="π VIBE:GOOD"),
    VibeThreshold(score=0.2, status="π VIBE:FLAT"),
    VibeThreshold(score=0.0, status="π VIBE:LOW"),
]
|
|
|
|
|
|
|
|
|
|
|
def map_score_to_vibe(score: float) -> VibeResult:
    """Map a cosine-similarity score to a VibeResult (status label, HTML, colour).

    The score is clamped to [0.0, 1.0] before use: negative similarities are
    treated as 0.0 and anything above 1.0 as 1.0. A NaN score is treated as
    0.0 (the lowest band) instead of slipping through the clamp.

    The colour hue runs linearly from 0 (red) at 0.0 to 120 (green) at 1.0.
    The status is the label of the highest-scoring entry in VIBE_THRESHOLDS
    whose ``score`` the clamped value meets or exceeds; if no entry matches,
    the lowest-scoring entry's label is used.

    Args:
        score: Similarity score, nominally in [-1.0, 1.0].

    Returns:
        A VibeResult whose ``raw_score`` is the unclamped input ``score``.
    """
    # NaN compares False against everything, so min()/max() would silently
    # turn it into 1.0 (the "best" band); pin it to the bottom instead.
    clamped_score = 0.0 if isnan(score) else max(0.0, min(1.0, score))

    # Hue 0 is red and 120 is green on the HSL colour wheel.
    hue = floor(clamped_score * 120)
    color_hsl = f"hsl({hue}, 80%, 50%)"

    # Choose by threshold value rather than list position so the lookup stays
    # correct even if VIBE_THRESHOLDS is ever reordered.
    reached = [t for t in VIBE_THRESHOLDS if clamped_score >= t.score]
    if reached:
        chosen = max(reached, key=lambda t: t.score)
    else:
        chosen = min(VIBE_THRESHOLDS, key=lambda t: t.score)
    status_text = chosen.status

    status_html = (
        f"<span style='color: {color_hsl}; font-weight: bold;'>"
        f"{status_text}</span>"
    )
    return VibeResult(raw_score=score, status_html=status_html, color_hsl=color_hsl)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VibeChecker:
    """Score texts by semantic similarity to a fixed anchor query.

    The anchor is embedded once, at construction time. Each call to
    ``check`` embeds one more text with the same model and prompt, then
    compares the two vectors.

    Attributes:
        model: SentenceTransformer used for every ``encode`` call.
        query_anchor: Reference text that checked texts are compared against.
        task_name: Passed as ``prompt_name`` to ``model.encode`` for both the
            anchor and every checked text.
        query_embedding: The anchor's normalized embedding, computed in
            ``__init__``.
    """

    def __init__(self, model: SentenceTransformer, query_anchor: str, task_name: str):
        self.model = model
        self.query_anchor = query_anchor
        self.task_name = task_name
        # Embed the anchor up front so check() only has to encode its input.
        self.query_embedding = self._embed(self.query_anchor)

    def _embed(self, text: str):
        """Encode ``text`` with this checker's model and prompt, normalized to unit length."""
        return self.model.encode(
            text,
            prompt_name=self.task_name,
            normalize_embeddings=True,
        )

    def check(self, text: str) -> VibeResult:
        """Return the VibeResult for ``text`` measured against the anchor query.

        Both embeddings are normalized to unit length, so their dot product
        equals their cosine similarity; that value goes to map_score_to_vibe.
        """
        candidate_embedding = self._embed(text)
        similarity = util.dot_score(self.query_embedding, candidate_embedding)
        return map_score_to_vibe(similarity.item())
|
|
|