File size: 5,831 Bytes
fe846c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d72603f
fc317f1
d72603f
fe846c2
 
 
 
1c59066
 
fe846c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import faiss
import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any

from sentence_transformers import SentenceTransformer

# Runtime configuration, read once from the environment at import time.
# LLM rewriting is enabled only when USE_LLM is exactly the string "1".
USE_LLM = os.getenv("USE_LLM", "1") == "1"
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "google/flan-t5-small")
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Number of neighbours / suggestions considered by the classifier.
TOP_K = int(os.getenv("TOP_K", 5))
# Minimum cosine similarity for a fuzzy (non-exact) title match to count as "belongs".
SIM_THRESHOLD_STRICT = float(os.getenv("SIM_THRESHOLD_STRICT", 0.90))

# Lazily created LLM pipeline: transformers is only imported (and the model
# only loaded) when rewriting is enabled and the pipeline is first requested.
_llm_pipe = None


def get_llm():
    """Return the shared text2text-generation pipeline, or None if USE_LLM is off.

    The pipeline is built on the first call and cached in the module-level
    ``_llm_pipe``; subsequent calls return the cached object.
    """
    global _llm_pipe
    if not USE_LLM or _llm_pipe is not None:
        return _llm_pipe
    from transformers import pipeline
    _llm_pipe = pipeline("text2text-generation", model=LLM_MODEL_NAME)
    return _llm_pipe

def normalize_title(t: str) -> str:
    """Return a canonical form of a song title for exact-match lookups.

    Case is folded (``str.casefold``), every character that is neither
    alphanumeric nor whitespace is dropped, and runs of whitespace collapse to
    a single space with no leading/trailing space. As a result
    "  Hound   Dog! " and "hound dog" normalize identically.

    Non-string input (e.g. a purely numeric title parsed by pandas) is
    converted with ``str()`` first instead of raising AttributeError.
    """
    kept = "".join(ch for ch in str(t).casefold() if ch.isalnum() or ch.isspace())
    # Collapse whitespace *after* removing punctuation, so "- Hound Dog" does
    # not keep the space that followed the dropped "-".
    return " ".join(kept.split())

@dataclass
class RAGConfig:
    """Settings used to build a song index for a single music genre."""

    # Path to the songs CSV (must provide a "title" column; "artist" is optional).
    songs_csv: str
    # Cache location. No code visible in this module uses it for caching;
    # TODO(review): confirm intended use.
    cache: str
    # Human-readable genre name.
    genre_name: str = "Rock & Roll"
    
class SongIndex:
    """Embedding index over the song titles of one genre.

    Loads the song list (from CSV, or a small built-in fallback list when the
    CSV path does not exist), embeds every title with a SentenceTransformer
    model and stores the L2-normalized embeddings in a FAISS inner-product
    index, so search scores are cosine similarities.
    """

    def __init__(self, cfg: RAGConfig):
        """Load the dataset named by ``cfg.songs_csv`` and build the index.

        Raises:
            ValueError: if the CSV lacks a "title" column or has no usable titles.
        """
        self.cfg = cfg
        self.df = self._load_dataset(cfg.songs_csv)
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        titles = self.df["title"].tolist()
        self.index, self.embeddings = self._build_faiss(titles)
        # Exact-match lookup: normalized title -> row position in self.df.
        # If two titles normalize identically, the later row wins.
        self.norm_to_idx = {normalize_title(t): i for i, t in enumerate(titles)}

    def _load_dataset(self, path: str) -> pd.DataFrame:
        """Return a DataFrame with string columns "title" and "artist".

        Rows without a title are dropped; a missing artist becomes "".
        """
        if not os.path.exists(path):
            # Minimal built-in fallback dataset.
            data = [
                ("Johnny B. Goode", "Chuck Berry"),
                ("Hound Dog", "Elvis Presley"),
                ("Tutti Frutti", "Little Richard"),
                ("Great Balls of Fire", "Jerry Lee Lewis"),
                ("Rock Around the Clock", "Bill Haley & His Comets"),
                ("Blue Suede Shoes", "Carl Perkins"),
                ("Peggy Sue", "Buddy Holly"),
                ("La Bamba", "Ritchie Valens"),
                ("Jailhouse Rock", "Elvis Presley"),
                ("Lucille", "Little Richard"),
                ("Good Golly Miss Molly", "Little Richard"),
                ("Long Tall Sally", "Little Richard"),
                ("Whole Lotta Shakin' Goin' On", "Jerry Lee Lewis"),
                ("Summertime Blues", "Eddie Cochran"),
                ("That'll Be the Day", "Buddy Holly"),
            ]
            df = pd.DataFrame(data, columns=["title", "artist"])
        else:
            df = pd.read_csv(path)
            if "title" not in df.columns:
                raise ValueError(f"{path!r} has no 'title' column")
            if "artist" not in df.columns:
                df["artist"] = ""
        df = df.dropna(subset=["title"]).copy()
        # pandas parses purely numeric titles (e.g. "1999") as numbers; the
        # encoder and normalize_title need strings.
        df["title"] = df["title"].astype(str).str.strip()
        # Missing artists would otherwise surface as NaN ("(nan)" in messages).
        df["artist"] = df["artist"].fillna("").astype(str).str.strip()
        df = df[df["title"] != ""].reset_index(drop=True)
        if df.empty:
            raise ValueError(f"no usable song titles found (source: {path!r})")
        return df

    def _build_faiss(self, titles: List[str]):
        """Embed ``titles`` and return ``(faiss_index, embeddings)``."""
        embs = self.model.encode(titles, convert_to_numpy=True, normalize_embeddings=True)
        # FAISS requires contiguous float32 input.
        embs = np.ascontiguousarray(embs, dtype=np.float32)
        # Cosine similarity as inner product (embeddings are L2-normalized).
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        return index, embs

    def search(self, query_title: str, top_k: int = TOP_K) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(row_indices, similarities)`` of the nearest titles, best first.

        ``top_k`` is clamped to the number of indexed titles. Without the clamp,
        FAISS pads the result with index -1, which callers would feed to
        ``df.iloc`` (selecting the last row). A non-positive ``top_k`` yields
        two empty arrays.
        """
        k = min(int(top_k), self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)
        q_emb = self.model.encode(
            [query_title], convert_to_numpy=True, normalize_embeddings=True
        ).astype(np.float32)
        sims, idxs = self.index.search(q_emb, k)
        return idxs[0], sims[0]

class Responder:
    """Build the user-facing (Spanish) answer messages for one genre.

    If an LLM pipeline is available, each template message is passed through
    it for a friendlier rewording. Otherwise the template text is returned
    verbatim.
    """

    def __init__(self, genre_name: str):
        self.genre = genre_name
        # get_llm() returns None when USE_LLM is disabled.
        self.llm = get_llm()

    @staticmethod
    def _clean(value: Any) -> str:
        """Return ``value`` as a stripped string; None/NaN become ''."""
        if value is None or pd.isna(value):
            return ""
        return str(value).strip()

    def _rewrite(self, text: str) -> str:
        """Reword ``text`` with the LLM, falling back to ``text`` itself.

        The fallback applies when no LLM is configured or when the model
        produces empty output.
        """
        if self.llm is None:
            return text
        out = self.llm(f"Reescribe cordial y breve: {text}", max_new_tokens=96)[0]["generated_text"]
        out = out.strip()
        # A small model can return nothing useful; never send an empty reply.
        return out if out else text

    def yes_msg(self, user_title: str, hit_title: str, artist: str) -> str:
        """Message confirming that ``hit_title`` belongs to the genre."""
        artist = self._clean(artist)
        base = f'La canción "{hit_title}" SÍ pertenece al género {self.genre}' + (f" ({artist})." if artist else ".")
        return self._rewrite(base)

    def no_msg(self, user_title: str, suggestions: List[Tuple[str, str]]) -> str:
        """Message rejecting ``user_title``, listing ``(title, artist)`` suggestions if any."""
        base = f'La canción "{user_title}" NO pertenece al género {self.genre}.'
        if suggestions:
            parts = []
            for title, artist in suggestions:
                artist = self._clean(artist)
                parts.append(f'"{title}"' + (f" ({artist})" if artist else ""))
            # Only mention options when there are some; avoids a dangling ": .".
            base += f" Algunas opciones del género: {'; '.join(parts)}."
        return self._rewrite(base)

def _row_artist(row: pd.Series) -> str:
    """Return a dataset row's artist as a string; '' when the column is absent or NaN."""
    artist = row.get("artist", "")
    return "" if pd.isna(artist) else str(artist)


def _belongs_result(responder: Responder, user_title: str, row: pd.Series, similarity: float) -> Dict[str, Any]:
    """Build the positive ("belongs") classification result for a matched row."""
    artist = _row_artist(row)
    return {
        "belongs": True,
        "matched_title": row["title"],
        "artist": artist,
        "similarity": similarity,
        "message": responder.yes_msg(user_title, row["title"], artist),
        "suggestions": [],
    }


def classify_title(song_index: SongIndex, responder: Responder, user_title: str) -> Dict[str, Any]:
    """Decide whether ``user_title`` is a song of the indexed genre.

    Matching happens in two steps:
      1. Exact match on the normalized title (similarity reported as 1.0).
      2. Embedding search. The best hit counts as a match when its similarity
         is >= SIM_THRESHOLD_STRICT.
    Otherwise the title is rejected, and up to TOP_K distinct nearest titles
    are returned as suggestions.

    Returns:
        dict with keys "belongs", "matched_title", "artist", "similarity",
        "message" and "suggestions" (a list of {"title", "artist"} dicts).
    """
    norm = normalize_title(user_title)
    exact_i = song_index.norm_to_idx.get(norm)
    if exact_i is not None:
        return _belongs_result(responder, user_title, song_index.df.iloc[exact_i], 1.0)

    idxs, sims = song_index.search(user_title, top_k=max(TOP_K, 5))
    # FAISS pads results with index -1 when fewer neighbours exist than were
    # requested; df.iloc[-1] would silently pick the last row, so drop them.
    hits = [(int(i), float(s)) for i, s in zip(idxs, sims) if i >= 0]
    best_sim = hits[0][1] if hits else 0.0

    if hits and best_sim >= SIM_THRESHOLD_STRICT:
        return _belongs_result(responder, user_title, song_index.df.iloc[hits[0][0]], best_sim)

    # Nearest titles as suggestions, skipping duplicates (by normalized title)
    # and anything equal to the query itself.
    suggestions: List[Tuple[str, str]] = []
    seen = {norm}
    for i, _ in hits:
        row = song_index.df.iloc[i]
        key = normalize_title(row["title"])
        if key not in seen:
            seen.add(key)
            suggestions.append((row["title"], _row_artist(row)))
    suggestions = suggestions[:max(TOP_K, 0)]
    return {
        "belongs": False,
        "matched_title": None,
        "artist": None,
        "similarity": best_sim,
        "message": responder.no_msg(user_title, suggestions),
        "suggestions": [{"title": t, "artist": a} for t, a in suggestions],
    }