Spaces:
Sleeping
Sleeping
| import os | |
| import faiss | |
| import pandas as pd | |
| import numpy as np | |
| from dataclasses import dataclass | |
| from typing import List, Tuple, Dict, Any | |
| from sentence_transformers import SentenceTransformer | |
# Runtime configuration, read once from environment variables at import time.

# The LLM rewrite step is enabled only when USE_LLM is exactly "1" (the default).
USE_LLM = os.getenv("USE_LLM", "1") == "1"

# Model identifiers handed to the text2text pipeline and to SentenceTransformer.
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "google/flan-t5-small")
EMBEDDING_MODEL_NAME = os.getenv(
    "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)

# Number of nearest neighbours requested / suggested per query.
TOP_K = int(os.getenv("TOP_K", "5"))

# Minimum similarity for a non-exact title to count as a match.
SIM_THRESHOLD_STRICT = float(os.getenv("SIM_THRESHOLD_STRICT", "0.90"))
# Lazy loading of the LLM so nothing breaks when it is not wanted.
_llm_pipe = None


def get_llm():
    """Return the shared text2text-generation pipeline, creating it on first use.

    Returns:
        The cached pipeline, or ``None`` when ``USE_LLM`` is false and no
        pipeline has been created.
    """
    global _llm_pipe
    if _llm_pipe is not None or not USE_LLM:
        return _llm_pipe

    # Imported here so `transformers` is only needed when the LLM is enabled.
    from transformers import pipeline

    _llm_pipe = pipeline("text2text-generation", model=LLM_MODEL_NAME)
    return _llm_pipe
def normalize_title(t: str) -> str:
    """Return a lookup key for a song title.

    Surrounding whitespace is stripped, every character that is neither
    alphanumeric nor whitespace is dropped, and the remaining characters are
    lower-cased one by one. Inner whitespace is kept as-is (not collapsed).
    """
    kept = []
    for char in t.strip():
        if char.isalnum() or char.isspace():
            kept.append(char.lower())
    return "".join(kept)
@dataclass
class RAGConfig:
    """Settings consumed by ``SongIndex`` (and, for the genre, by callers).

    The ``@dataclass`` decorator is required: without it the annotated fields
    are mere annotations, the class accepts no constructor arguments, and
    ``cfg.songs_csv`` / ``cfg.cache`` do not exist on instances.
    """

    # Path to the songs CSV; SongIndex falls back to a built-in list when the
    # path does not exist.
    songs_csv: str
    # Cache setting; in this file it is only printed by SongIndex.__init__.
    # NOTE(review): intended use (e.g. model cache dir) is not visible here.
    cache: str
    # Human-readable genre name used in response messages.
    genre_name: str = "Rock & Roll"
class SongIndex:
    """Song catalogue plus a FAISS inner-product index over title embeddings.

    Attributes:
        cfg: The configuration this index was built from.
        df: Catalogue with string ``title`` and ``artist`` columns.
        model: SentenceTransformer used to embed titles and queries.
        index: ``faiss.IndexFlatIP`` holding the normalized title embeddings.
        embeddings: The title embeddings as returned by the model.
        norm_to_idx: Maps ``normalize_title(title)`` to the row position in
            ``df`` (for duplicate normalized titles the last row wins).
    """

    def __init__(self, cfg: RAGConfig):
        """Load the catalogue, embed every title and build the index."""
        self.cfg = cfg
        self.df = self._load_dataset(cfg.songs_csv)
        print(cfg.cache)
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        titles = self.df["title"].tolist()
        self.index, self.embeddings = self._build_faiss(titles)
        self.norm_to_idx = {normalize_title(t): i for i, t in enumerate(titles)}

    def _load_dataset(self, path: str) -> pd.DataFrame:
        """Read the songs CSV at *path*, or use a small built-in fallback list.

        Returns:
            A DataFrame with a fresh integer index, no missing titles, and
            ``title`` / ``artist`` columns holding strings (missing artists
            become ``""``).

        Raises:
            KeyError: If the CSV has no ``title`` column.
        """
        if not os.path.exists(path):
            # Minimal fallback catalogue.
            data = [
                ("Johnny B. Goode", "Chuck Berry"),
                ("Hound Dog", "Elvis Presley"),
                ("Tutti Frutti", "Little Richard"),
                ("Great Balls of Fire", "Jerry Lee Lewis"),
                ("Rock Around the Clock", "Bill Haley & His Comets"),
                ("Blue Suede Shoes", "Carl Perkins"),
                ("Peggy Sue", "Buddy Holly"),
                ("La Bamba", "Ritchie Valens"),
                ("Jailhouse Rock", "Elvis Presley"),
                ("Lucille", "Little Richard"),
                ("Good Golly Miss Molly", "Little Richard"),
                ("Long Tall Sally", "Little Richard"),
                ("Whole Lotta Shakin' Goin' On", "Jerry Lee Lewis"),
                ("Summertime Blues", "Eddie Cochran"),
                ("That'll Be the Day", "Buddy Holly"),
            ]
            df = pd.DataFrame(data, columns=["title", "artist"])
        else:
            df = pd.read_csv(path)

        if "title" not in df.columns:
            raise KeyError(f"songs CSV {path!r} has no 'title' column")
        if "artist" not in df.columns:
            df["artist"] = ""

        df = df.dropna(subset=["title"]).reset_index(drop=True)
        # Titles are fed to str-only helpers (normalize_title, the encoder), so
        # numeric-looking titles parsed by read_csv must become strings.
        df["title"] = df["title"].astype(str)
        # Empty CSV cells are read as NaN, which is truthy and would be shown
        # as "(nan)" in messages; the rest of the code treats "" as "no artist".
        df["artist"] = df["artist"].fillna("").astype(str)
        return df

    def _build_faiss(self, titles: List[str]):
        """Embed *titles* and return ``(index, embeddings)``."""
        embs = self.model.encode(titles, convert_to_numpy=True, normalize_embeddings=True)
        dim = embs.shape[1]
        # Inner product equals cosine similarity because embeddings are normalized.
        index = faiss.IndexFlatIP(dim)
        index.add(embs.astype(np.float32))
        return index, embs

    def search(self, query_title: str, top_k: int = TOP_K) -> Tuple[np.ndarray, np.ndarray]:
        """Return the catalogue rows most similar to *query_title*.

        Args:
            query_title: Free-text title to look up.
            top_k: Maximum number of neighbours to return.

        Returns:
            ``(row_positions, similarities)`` ordered best-first. At most
            ``top_k`` entries are returned, fewer when the index is smaller;
            both arrays are empty when nothing can be returned.
        """
        # FAISS pads results with label -1 when asked for more neighbours than
        # it holds; callers use labels with df.iloc, where -1 means "last row".
        k = min(int(top_k), int(self.index.ntotal))
        if k <= 0:
            return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)

        q_emb = self.model.encode(
            [query_title], convert_to_numpy=True, normalize_embeddings=True
        ).astype(np.float32)
        sims, idxs = self.index.search(q_emb, k)
        found = idxs[0] >= 0
        return idxs[0][found], sims[0][found]
class Responder:
    """Builds the user-facing answer messages for one genre.

    When an LLM pipeline is available, every message is passed through it for
    rephrasing; otherwise the plain template text is returned.
    """

    def __init__(self, genre_name: str):
        """Remember the genre name and fetch the (possibly absent) LLM."""
        self.genre = genre_name
        self.llm = get_llm()

    def _rewrite(self, text: str) -> str:
        """Return *text* rephrased by the LLM, or unchanged when there is none."""
        if self.llm is not None:
            generations = self.llm(f"Reescribe cordial y breve: {text}", max_new_tokens=96)
            return generations[0]["generated_text"]
        return text

    def yes_msg(self, user_title: str, hit_title: str, artist: str) -> str:
        """Message stating that *hit_title* belongs to the genre.

        The artist is appended in parentheses only when it is truthy.
        *user_title* is accepted but not used in the text.
        """
        ending = f" ({artist})." if artist else "."
        return self._rewrite(
            f'La canción "{hit_title}" SÍ pertenece al género {self.genre}' + ending
        )

    def no_msg(self, user_title: str, suggestions: List[Tuple[str,str]]) -> str:
        """Message stating that *user_title* is not in the genre, with suggestions.

        *suggestions* is a list of ``(title, artist)`` pairs; each is rendered
        as a quoted title, followed by the artist in parentheses when truthy.
        """
        rendered = []
        for title, artist in suggestions:
            entry = f'"{title}"'
            if artist:
                entry += f" ({artist})"
            rendered.append(entry)
        sugg_text = "; ".join(rendered)
        return self._rewrite(
            f'La canción "{user_title}" NO pertenece al género {self.genre}. Algunas opciones del género: {sugg_text}.'
        )
def _artist_of(row) -> str:
    """Return the row's artist, or ``""`` when it is absent or missing (NaN)."""
    artist = row.get("artist", "")
    return "" if pd.isna(artist) else artist


def _match_result(responder: "Responder", user_title: str, row, similarity: float) -> Dict[str, Any]:
    """Build the positive ("belongs") result dict for the matched catalogue row."""
    artist = _artist_of(row)
    return {
        "belongs": True,
        "matched_title": row["title"],
        "artist": artist,
        "similarity": similarity,
        "message": responder.yes_msg(user_title, row["title"], artist),
        "suggestions": [],
    }


def classify_title(song_index: SongIndex, responder: Responder, user_title: str) -> Dict[str, Any]:
    """Decide whether *user_title* is in the catalogue and build the reply.

    The title matches when its normalized form equals a catalogue title
    (similarity reported as 1.0) or when the best semantic neighbour reaches
    ``SIM_THRESHOLD_STRICT``. Otherwise up to ``TOP_K`` nearest titles are
    offered as suggestions.

    Returns:
        A dict with keys ``belongs``, ``matched_title``, ``artist``,
        ``similarity``, ``message`` and ``suggestions`` (a list of
        ``{"title", "artist"}`` dicts, empty for positive results).
    """
    norm = normalize_title(user_title)
    exact_i = song_index.norm_to_idx.get(norm)
    if exact_i is not None:
        return _match_result(responder, user_title, song_index.df.iloc[exact_i], 1.0)

    idxs, sims = song_index.search(user_title, top_k=max(TOP_K, 5))
    # FAISS pads with label -1 when the index holds fewer vectors than
    # requested; df.iloc[-1] would silently pick the last row, so drop those.
    hits = [(int(i), float(s)) for i, s in zip(idxs, sims) if int(i) >= 0]
    if not hits:
        return {
            "belongs": False,
            "matched_title": None,
            "artist": None,
            "similarity": 0.0,
            "message": responder.no_msg(user_title, []),
            "suggestions": [],
        }

    best_i, best_sim = hits[0]
    if best_sim >= SIM_THRESHOLD_STRICT:
        return _match_result(responder, user_title, song_index.df.iloc[best_i], best_sim)

    suggestions: List[Tuple[str, str]] = []
    for i, _ in hits[:max(TOP_K, 0)]:
        row = song_index.df.iloc[i]
        if normalize_title(row["title"]) != norm:
            suggestions.append((row["title"], _artist_of(row)))

    return {
        "belongs": False,
        "matched_title": None,
        "artist": None,
        "similarity": best_sim,
        "message": responder.no_msg(user_title, suggestions),
        "suggestions": [{"title": t, "artist": a} for t, a in suggestions],
    }