# rock_chat/code/rag_core.py
# Author: smitharauco — last commit 1c59066 ("update cache")
import os
import faiss
import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any
from sentence_transformers import SentenceTransformer
# Runtime configuration; every value can be overridden via environment variables.
# The LLM is enabled only when USE_LLM is exactly "1" (the default).
USE_LLM = os.getenv("USE_LLM", "1") == "1"
LLM_MODEL_NAME = os.getenv("LLM_MODEL", "google/flan-t5-small")
EMBEDDING_MODEL_NAME = os.getenv(
    "EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
# Default number of neighbours for searches and the cap on suggestions.
TOP_K = int(os.getenv("TOP_K", "5"))
# Minimum similarity for a fuzzy (non-exact) title match to count as a hit.
SIM_THRESHOLD_STRICT = float(os.getenv("SIM_THRESHOLD_STRICT", "0.90"))

# Lazily created LLM pipeline: `transformers` is only imported, and the model
# only loaded, when the LLM is enabled and actually requested.
_llm_pipe = None


def get_llm():
    """Return the shared text2text-generation pipeline, creating it on first use.

    Returns:
        The cached pipeline if one exists; otherwise ``None`` when
        ``USE_LLM`` is false, or a freshly built pipeline for
        ``LLM_MODEL_NAME`` when it is true.
    """
    global _llm_pipe
    if _llm_pipe is not None or not USE_LLM:
        return _llm_pipe
    from transformers import pipeline  # deferred: optional, heavy dependency

    _llm_pipe = pipeline("text2text-generation", model=LLM_MODEL_NAME)
    return _llm_pipe
def normalize_title(t: str) -> str:
    """Return a canonical key for comparing song titles.

    Lower-cases *t*, drops every character that is neither alphanumeric nor
    whitespace, then collapses whitespace runs into single spaces and trims
    both ends.

    >>> normalize_title("  Johnny B.  Goode ")
    'johnny b goode'
    """
    kept = "".join(ch.lower() for ch in t if ch.isalnum() or ch.isspace())
    # Collapse *after* filtering: removing punctuation such as "&" or a
    # trailing "." can leave doubled or edge spaces that would otherwise
    # defeat exact matching ("Hound Dog ." vs "Hound Dog").
    return " ".join(kept.split())
@dataclass
class RAGConfig:
    """Settings needed to build a ``SongIndex``.

    Positional field order: ``songs_csv``, ``cache``, ``genre_name``.
    """

    # Path to the songs CSV; SongIndex falls back to a built-in sample list
    # when this path does not exist.
    songs_csv: str
    # Cache location; its purpose is not evident from this module — TODO confirm.
    cache: str
    # Human-readable genre label (presumably handed to Responder by the caller).
    genre_name: str = "Rock & Roll"
class SongIndex:
    """Semantic index over the song titles of one genre.

    Loads a song catalogue (a CSV file, or a built-in sample list), embeds
    every title with a SentenceTransformer model and stores the L2-normalized
    vectors in a FAISS inner-product index, so scores are cosine similarities.
    """

    def __init__(self, cfg: RAGConfig):
        """Load the catalogue described by *cfg* and build the search index.

        ``cfg`` is kept on ``self.cfg``; ``cfg.cache`` is not read by this class.
        """
        self.cfg = cfg
        self.df = self._load_dataset(cfg.songs_csv)
        self.model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        titles = self.df["title"].tolist()
        self.index, self.embeddings = self._build_faiss(titles)
        # Exact-match lookup: normalized title -> row position in self.df.
        # If two titles normalize identically, the later row wins.
        self.norm_to_idx = {normalize_title(t): i for i, t in enumerate(titles)}

    def _load_dataset(self, path: str) -> pd.DataFrame:
        """Return a DataFrame with string ``title`` and ``artist`` columns.

        Uses a small built-in list of songs when *path* does not exist.
        Rows without a title are dropped; missing artists become "".

        Raises:
            KeyError: if the CSV has no ``title`` column.
        """
        if not os.path.exists(path):
            # Minimal fallback catalogue.
            data = [
                ("Johnny B. Goode", "Chuck Berry"),
                ("Hound Dog", "Elvis Presley"),
                ("Tutti Frutti", "Little Richard"),
                ("Great Balls of Fire", "Jerry Lee Lewis"),
                ("Rock Around the Clock", "Bill Haley & His Comets"),
                ("Blue Suede Shoes", "Carl Perkins"),
                ("Peggy Sue", "Buddy Holly"),
                ("La Bamba", "Ritchie Valens"),
                ("Jailhouse Rock", "Elvis Presley"),
                ("Lucille", "Little Richard"),
                ("Good Golly Miss Molly", "Little Richard"),
                ("Long Tall Sally", "Little Richard"),
                ("Whole Lotta Shakin' Goin' On", "Jerry Lee Lewis"),
                ("Summertime Blues", "Eddie Cochran"),
                ("That'll Be the Day", "Buddy Holly"),
            ]
            df = pd.DataFrame(data, columns=["title", "artist"])
        else:
            df = pd.read_csv(path)
            if "artist" not in df.columns:
                df["artist"] = ""
        df = df.dropna(subset=["title"]).reset_index(drop=True)
        # Empty CSV cells arrive as NaN, which is truthy and would be rendered
        # as "(nan)" in messages; normalize them to "".
        df["artist"] = df["artist"].fillna("").astype(str)
        # Titles must be strings for both the encoder and normalize_title
        # (a purely numeric title column would otherwise be parsed as numbers).
        df["title"] = df["title"].astype(str)
        return df

    def _build_faiss(self, titles: List[str]):
        """Embed *titles* and return ``(faiss_index, embeddings)``.

        Embeddings are L2-normalized, so the inner-product index ranks by
        cosine similarity.
        """
        embs = self.model.encode(titles, convert_to_numpy=True, normalize_embeddings=True)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs.astype(np.float32))
        return index, embs

    def search(self, query_title: str, top_k: int = TOP_K) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(row_indices, similarities)`` for the titles nearest *query_title*.

        Results are ordered best first. *top_k* is clamped to the number of
        indexed titles, because FAISS pads missing neighbours with index -1,
        which ``DataFrame.iloc`` would silently read as "last row".
        """
        q_emb = self.model.encode(
            [query_title], convert_to_numpy=True, normalize_embeddings=True
        ).astype(np.float32)
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)
        sims, idxs = self.index.search(q_emb, k)
        return idxs[0], sims[0]
class Responder:
    """Builds the (Spanish) answer messages, optionally rephrased by the LLM."""

    def __init__(self, genre_name: str):
        self.genre = genre_name
        # get_llm() yields None when the LLM is disabled; _rewrite then
        # returns messages verbatim.
        self.llm = get_llm()

    @staticmethod
    def _has_artist(artist: Any) -> bool:
        """Return True if *artist* is a usable name (not empty, None or NaN)."""
        # NaN (a missing CSV cell) is truthy, so a plain `if artist` would
        # print "(nan)".
        return not pd.isna(artist) and bool(artist)

    def _rewrite(self, text: str) -> str:
        """Ask the LLM to rephrase *text* politely and briefly; fall back to *text*."""
        if self.llm is None:
            return text
        out = self.llm(f"Reescribe cordial y breve: {text}", max_new_tokens=96)[0]["generated_text"]
        # Small seq2seq models sometimes emit an empty/blank answer; never
        # replace a correct deterministic message with nothing.
        out = out.strip()
        return out or text

    def yes_msg(self, user_title: str, hit_title: str, artist: str) -> str:
        """Message stating that *hit_title* belongs to the genre.

        *user_title* is accepted for interface symmetry but not used.
        """
        suffix = f" ({artist})." if self._has_artist(artist) else "."
        base = f'La canción "{hit_title}" SÍ pertenece al género {self.genre}' + suffix
        return self._rewrite(base)

    def no_msg(self, user_title: str, suggestions: List[Tuple[str, str]]) -> str:
        """Message stating that *user_title* is not in the genre.

        Lists *suggestions* as ``(title, artist)`` pairs. The suggestion
        sentence is omitted when there are none.
        """
        base = f'La canción "{user_title}" NO pertenece al género {self.genre}.'
        if suggestions:
            sugg_text = "; ".join(
                f'"{t}"' + (f" ({a})" if self._has_artist(a) else "")
                for t, a in suggestions
            )
            base += f" Algunas opciones del género: {sugg_text}."
        return self._rewrite(base)
def classify_title(song_index: SongIndex, responder: Responder, user_title: str) -> Dict[str, Any]:
    """Decide whether *user_title* names a song of the indexed genre.

    Matching order:
      1. exact match after ``normalize_title`` -> similarity 1.0;
      2. otherwise the nearest title by embedding, accepted when its
         similarity is >= ``SIM_THRESHOLD_STRICT``;
      3. otherwise a negative answer with up to ``TOP_K`` nearest titles as
         suggestions.

    Returns:
        A dict with keys ``belongs`` (bool), ``matched_title`` and ``artist``
        (``None`` when not belonging), ``similarity`` (float; 0.0 when the
        search yields no valid neighbour), ``message`` (built by *responder*)
        and ``suggestions`` (list of ``{"title", "artist"}`` dicts).
    """

    def _artist_of(row: pd.Series) -> str:
        # Empty CSV cells arrive as NaN, which is truthy and not JSON-friendly.
        artist = row.get("artist", "")
        return "" if pd.isna(artist) else artist

    def _hit(row: pd.Series, similarity: float) -> Dict[str, Any]:
        # Positive answer shared by the exact and the fuzzy match paths.
        artist = _artist_of(row)
        return {
            "belongs": True,
            "matched_title": row["title"],
            "artist": artist,
            "similarity": similarity,
            "message": responder.yes_msg(user_title, row["title"], artist),
            "suggestions": [],
        }

    norm = normalize_title(user_title)
    if norm in song_index.norm_to_idx:
        return _hit(song_index.df.iloc[song_index.norm_to_idx[norm]], 1.0)

    idxs, sims = song_index.search(user_title, top_k=max(TOP_K, 5))
    # FAISS pads results with index -1 when fewer than k titles are indexed;
    # df.iloc[-1] would silently pick the last row, so drop those slots.
    hits = [(int(i), float(s)) for i, s in zip(idxs, sims) if i >= 0]
    best_sim = hits[0][1] if hits else 0.0
    if hits and best_sim >= SIM_THRESHOLD_STRICT:
        return _hit(song_index.df.iloc[hits[0][0]], best_sim)

    suggestions: List[Tuple[str, str]] = []
    for i, _ in hits[:TOP_K]:
        r = song_index.df.iloc[i]
        # Defensive: exact normalized matches were already answered above.
        if normalize_title(r["title"]) != norm:
            suggestions.append((r["title"], _artist_of(r)))
    return {
        "belongs": False,
        "matched_title": None,
        "artist": None,
        "similarity": best_sim,
        "message": responder.no_msg(user_title, suggestions),
        "suggestions": [{"title": t, "artist": a} for t, a in suggestions],
    }