Spaces:
Sleeping
Sleeping
File size: 5,831 Bytes
fe846c2 d72603f fc317f1 d72603f fe846c2 1c59066 fe846c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import faiss
import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any
from sentence_transformers import SentenceTransformer
# Runtime configuration, read once from environment variables at import time.
# Only the exact string "1" enables the LLM rewriter.
USE_LLM = os.getenv("USE_LLM", "1") == "1"
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "google/flan-t5-small")
EMBEDDING_MODEL_NAME = os.getenv("EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
# Number of neighbours used for suggestions.
TOP_K = int(os.getenv("TOP_K", "5"))
# Minimum cosine similarity for a fuzzy match to count as "belongs to the genre".
SIM_THRESHOLD_STRICT = float(os.getenv("SIM_THRESHOLD_STRICT", "0.90"))

# Lazily-loaded LLM pipeline (see get_llm), so nothing is loaded when the LLM is not wanted.
_llm_pipe = None
def get_llm():
    """Return the shared text2text-generation pipeline, creating it on first use.

    Returns the cached pipeline if one already exists. Otherwise, when
    ``USE_LLM`` is false, returns ``None`` without loading anything. The
    ``transformers`` import happens here, so that package is only needed
    when the LLM is actually enabled.
    """
    global _llm_pipe
    # Nothing to do: already loaded, or the LLM is switched off.
    if _llm_pipe is not None or not USE_LLM:
        return _llm_pipe
    from transformers import pipeline
    _llm_pipe = pipeline("text2text-generation", model=LLM_MODEL_NAME)
    return _llm_pipe
def normalize_title(t: str) -> str:
    """Normalize a song title into a key for exact-match lookups.

    The text is case-folded, and every character that is neither
    alphanumeric nor whitespace (punctuation, symbols) is dropped. Runs of
    whitespace are then collapsed into single spaces, so
    ``"  Hound  Dog! "``, ``"- Hound Dog"`` and ``"hound dog"`` all map to
    the same key.

    Args:
        t: Raw title. ``None`` yields ``""``. Other non-string values, such
            as numbers parsed from a CSV, are converted with ``str`` first.

    Returns:
        The normalized key, or ``""`` when nothing alphanumeric remains.
    """
    if t is None:
        return ""
    kept = "".join(ch for ch in str(t).casefold() if ch.isalnum() or ch.isspace())
    # Dropping punctuation can leave leading, trailing or doubled spaces
    # ("Rock & Roll" -> "rock  roll"); split/join collapses them.
    return " ".join(kept.split())
@dataclass
class RAGConfig:
    """Settings used to build the song index.

    Attributes:
        songs_csv: Path to the song catalogue CSV.
        cache: Cache setting; its intended use is not evident from this
            module (TODO confirm).
        genre_name: Human-readable genre label; defaults to "Rock & Roll".
    """

    songs_csv: str  # catalogue CSV path
    cache: str  # NOTE(review): purpose not evident from this module — confirm
    genre_name: str = "Rock & Roll"  # genre label
class SongIndex:
    """Searchable catalogue of the genre's song titles.

    Loads the catalogue from a CSV file, or from a small built-in fallback
    list when the file is missing. Every title is embedded with a
    SentenceTransformer model, and the L2-normalized embeddings are stored
    in a FAISS inner-product index, so the inner product equals the cosine
    similarity between titles.
    """

    def __init__(self, cfg: RAGConfig):
        self.cfg = cfg
        self.df = self._load_dataset(cfg.songs_csv)
        if self.df.empty:
            # The encoder/FAISS cannot build an index from zero titles; fail clearly
            # before loading the embedding model.
            raise ValueError(f"No song titles found in {cfg.songs_csv!r}")
        # NOTE: cfg.cache is currently not used by this class.
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        titles = self.df["title"].tolist()
        self.index, self.embeddings = self._build_faiss(titles)
        # Normalized title -> row position, for exact lookups. If two titles
        # normalize to the same key, the later row wins.
        self.norm_to_idx = {normalize_title(t): i for i, t in enumerate(titles)}

    def _load_dataset(self, path: str) -> pd.DataFrame:
        """Load the catalogue as a DataFrame with string ``title`` and ``artist`` columns.

        Falls back to a built-in list of classic rock & roll songs when
        *path* does not exist. Rows without a title, or whose title is blank,
        are dropped. Missing artists become ``""``. The index is reset to
        0..n-1 so row positions match FAISS ids.
        """
        if not os.path.exists(path):
            # Minimal fallback catalogue.
            data = [
                ("Johnny B. Goode", "Chuck Berry"),
                ("Hound Dog", "Elvis Presley"),
                ("Tutti Frutti", "Little Richard"),
                ("Great Balls of Fire", "Jerry Lee Lewis"),
                ("Rock Around the Clock", "Bill Haley & His Comets"),
                ("Blue Suede Shoes", "Carl Perkins"),
                ("Peggy Sue", "Buddy Holly"),
                ("La Bamba", "Ritchie Valens"),
                ("Jailhouse Rock", "Elvis Presley"),
                ("Lucille", "Little Richard"),
                ("Good Golly Miss Molly", "Little Richard"),
                ("Long Tall Sally", "Little Richard"),
                ("Whole Lotta Shakin' Goin' On", "Jerry Lee Lewis"),
                ("Summertime Blues", "Eddie Cochran"),
                ("That'll Be the Day", "Buddy Holly"),
            ]
            df = pd.DataFrame(data, columns=["title", "artist"])
        else:
            df = pd.read_csv(path)
        if "artist" not in df.columns:
            df["artist"] = ""
        df = df.dropna(subset=["title"]).assign(
            # read_csv may parse numeric-looking titles as numbers; the encoder
            # and normalize_title need text.
            title=lambda d: d["title"].astype(str).str.strip(),
            # Missing artists come back as NaN, which is truthy and would be
            # rendered as "(nan)" in user messages.
            artist=lambda d: d["artist"].fillna("").astype(str).str.strip(),
        )
        return df[df["title"] != ""].reset_index(drop=True)

    def _build_faiss(self, titles: List[str]):
        """Embed *titles* and return ``(faiss_index, embeddings)``.

        Embeddings are L2-normalized, so ``IndexFlatIP`` scores are cosine
        similarities. Row i of the index corresponds to ``titles[i]``.
        """
        embs = self.model.encode(titles, convert_to_numpy=True, normalize_embeddings=True)
        # FAISS expects a C-contiguous float32 matrix.
        embs = np.ascontiguousarray(embs, dtype=np.float32)
        index = faiss.IndexFlatIP(embs.shape[1])  # cosine as inner product (normalized embs)
        index.add(embs)
        return index, embs

    def search(self, query_title: str, top_k: int = TOP_K) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(row_indices, cosine_similarities)`` of the nearest titles, best first.

        *top_k* is clamped to the number of indexed titles. FAISS pads
        missing results with label -1, which ``iloc[-1]`` would silently
        turn into the last row, so any such entries are dropped. The result
        may therefore hold fewer than *top_k* entries.
        """
        k = min(int(top_k), self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)
        q_emb = self.model.encode([query_title], convert_to_numpy=True, normalize_embeddings=True)
        q_emb = np.ascontiguousarray(q_emb, dtype=np.float32)
        sims, idxs = self.index.search(q_emb, k)
        keep = idxs[0] >= 0
        return idxs[0][keep], sims[0][keep]
class Responder:
    """Builds the user-facing yes/no answers for a genre, optionally rewritten by an LLM."""

    def __init__(self, genre_name: str):
        self.genre = genre_name
        # Shared pipeline from get_llm(); None when the LLM is disabled.
        self.llm = get_llm()

    @staticmethod
    def _clean_artist(artist: Any) -> str:
        """Return *artist* stripped, or ``""`` for missing values (None, NaN, non-strings)."""
        return artist.strip() if isinstance(artist, str) else ""

    def _rewrite(self, text: str) -> str:
        """Ask the LLM to rephrase *text* politely and briefly.

        Returns *text* unchanged when there is no LLM, or when the model
        produces an empty or whitespace-only generation, so the user never
        receives a blank answer.
        """
        if self.llm is None:
            return text
        out = self.llm(f"Reescribe cordial y breve: {text}", max_new_tokens=96)[0]["generated_text"]
        out = out.strip()
        return out or text

    def yes_msg(self, user_title: str, hit_title: str, artist: str) -> str:
        """Message saying *hit_title* belongs to the genre (*user_title* is unused)."""
        artist = self._clean_artist(artist)
        base = f'La canción "{hit_title}" SÍ pertenece al género {self.genre}' + (f" ({artist})." if artist else ".")
        return self._rewrite(base)

    def no_msg(self, user_title: str, suggestions: List[Tuple[str, str]]) -> str:
        """Message saying *user_title* is not in the genre, listing ``(title, artist)`` suggestions.

        The suggestion clause is omitted when *suggestions* is empty. This
        avoids a dangling "Algunas opciones del género: ." sentence.
        """
        base = f'La canción "{user_title}" NO pertenece al género {self.genre}.'
        formatted = []
        for title, artist in suggestions:
            artist = self._clean_artist(artist)
            formatted.append(f'"{title}"' + (f" ({artist})" if artist else ""))
        if formatted:
            sugg_text = "; ".join(formatted)
            base += f" Algunas opciones del género: {sugg_text}."
        return self._rewrite(base)
def _artist_text(value: Any) -> str:
    """Return *value* as an artist string, mapping missing values (None, NaN, non-str) to ""."""
    return value.strip() if isinstance(value, str) else ""


def _match_result(responder: Responder, user_title: str, row: pd.Series, similarity: float) -> Dict[str, Any]:
    """Build the positive ("belongs") result dict for a matched catalogue row."""
    artist = _artist_text(row.get("artist", ""))
    return {
        "belongs": True,
        "matched_title": row["title"],
        "artist": artist,
        "similarity": similarity,
        "message": responder.yes_msg(user_title, row["title"], artist),
        "suggestions": [],
    }


def classify_title(song_index: SongIndex, responder: Responder, user_title: str) -> Dict[str, Any]:
    """Decide whether *user_title* is a song of the indexed genre.

    The decision has three steps:

    1. Exact match on the normalized title: belongs, with similarity 1.0.
    2. Otherwise, take the nearest title by embedding. It belongs when the
       cosine similarity is at least ``SIM_THRESHOLD_STRICT``.
    3. Otherwise, the title is not in the genre. Up to ``TOP_K`` nearest
       titles are returned as suggestions.

    Returns:
        A dict with keys ``belongs``, ``matched_title``, ``artist``,
        ``similarity``, ``message`` and ``suggestions``. The ``suggestions``
        value is a list of ``{"title", "artist"}`` dicts. Missing artists are
        reported as ``""``. When the search yields no neighbours,
        ``similarity`` is 0.0.
    """
    norm = normalize_title(user_title)
    exact = song_index.norm_to_idx.get(norm)
    if exact is not None:
        return _match_result(responder, user_title, song_index.df.iloc[exact], 1.0)

    idxs, sims = song_index.search(user_title, top_k=max(TOP_K, 5))
    # FAISS pads missing neighbours with label -1; iloc[-1] would silently pick
    # the last row, so such entries are discarded.
    hits = [(int(i), float(s)) for i, s in zip(idxs, sims) if int(i) >= 0]
    if hits and hits[0][1] >= SIM_THRESHOLD_STRICT:
        best_i, best_sim = hits[0]
        return _match_result(responder, user_title, song_index.df.iloc[best_i], best_sim)

    suggestions: List[Tuple[str, str]] = []
    for i, _ in hits[:TOP_K]:
        row = song_index.df.iloc[i]
        if normalize_title(row["title"]) != norm:
            suggestions.append((row["title"], _artist_text(row.get("artist", ""))))
    return {
        "belongs": False,
        "matched_title": None,
        "artist": None,
        "similarity": hits[0][1] if hits else 0.0,
        "message": responder.no_msg(user_title, suggestions),
        "suggestions": [{"title": t, "artist": a} for t, a in suggestions],
    }