TutuAwad committed on
Commit
c3ed2a9
·
verified ·
1 Parent(s): c5c2b68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -10
app.py CHANGED
@@ -30,13 +30,24 @@ df_embeddings = emb_data["df_embeddings"].astype("float32")
30
  index = faiss.read_index(INDEX_PATH)
31
 
32
  # ---------- Secrets ----------
 
33
  HF_TOKEN = os.getenv("HF_TOKEN")
34
  SPOTIFY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID")
35
- SPOTIFY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET")
36
 
37
  # ---------- Models ----------
38
  query_embedder = SentenceTransformer("all-mpnet-base-v2")
39
- hf_client = InferenceClient(model="meta-llama/Llama-2-7b-chat-hf", token=HF_TOKEN)
 
 
 
 
 
 
 
 
 
 
40
 
41
  sp = None
42
  if SPOTIFY_CLIENT_ID and SPOTIFY_CLIENT_SECRET:
@@ -47,17 +58,48 @@ if SPOTIFY_CLIENT_ID and SPOTIFY_CLIENT_SECRET:
47
  def encode_query(text):
48
  return query_embedder.encode([text], convert_to_numpy=True).astype("float32")
49
 
50
- def expand_with_llama(query):
51
- if not hf_client:
 
 
 
 
 
 
 
 
52
  return query
 
53
  prompt = f"""You are helping someone search a lyrics catalog.
54
- If the input looks like lyrics or a singer name, return artist and song titles that match.
55
- Otherwise, return a short list of lyric-style keywords related to the input sentence.
56
 
57
- Input: {query}
58
- Output:"""
59
- response = hf_client.text_generation(prompt, max_new_tokens=96, temperature=0.2, repetition_penalty=1.05)
60
- return query + " " + str(response).strip().replace("\n", " ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def distances_to_similarity_pct(dists):
63
  if len(dists) == 0: return np.array([])
 
30
index = faiss.read_index(INDEX_PATH)

# ---------- Secrets ----------

HF_TOKEN = os.getenv("HF_TOKEN")
SPOTIFY_CLIENT_ID = os.getenv("SPOTIPY_CLIENT_ID")
# BUGFIX: the env-var key follows the spotipy convention (SPOTIPY_*), but the
# Python name must stay SPOTIFY_CLIENT_SECRET — the Spotify-init guard below
# checks `SPOTIFY_CLIENT_SECRET`, and renaming the variable (as this commit
# did) leaves that name unassigned and raises NameError at import time.
SPOTIFY_CLIENT_SECRET = os.getenv("SPOTIPY_CLIENT_SECRET")

# ---------- Models ----------
query_embedder = SentenceTransformer("all-mpnet-base-v2")

LLAMA_MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"

# Create a generic client; we'll pass model per call.
# Initialization is best-effort: if the HF client cannot be created the app
# still runs, and expand_with_llama falls back to the raw query.
hf_client = None
if HF_TOKEN:
    try:
        hf_client = InferenceClient(token=HF_TOKEN)
    except Exception as e:
        print("⚠️ Could not initialize HF Inference client:", repr(e))
        hf_client = None
51
 
52
  sp = None
53
  if SPOTIFY_CLIENT_ID and SPOTIFY_CLIENT_SECRET:
 
58
def encode_query(text):
    """Embed a single query string as a float32 NumPy row vector."""
    vec = query_embedder.encode([text], convert_to_numpy=True)
    return vec.astype("float32")
60
 
61
def expand_with_llama(query: str) -> str:
    """
    Enrich the query using LLaMA via HF Inference.

    On HF Spaces, the Inference provider can sometimes be unavailable
    or misconfigured (giving the StopIteration error you saw). In that
    case, we log and fall back to the raw query so the UI keeps working.
    """
    # No client/token -> behave as if expansion were disabled.
    if hf_client is None or not HF_TOKEN:
        return query

    prompt = f"""You are helping someone search a lyrics catalog.

If the input looks like existing song lyrics or a singer name,
return artist and song titles that match.

Otherwise, return a short list of lyric-style keywords
that are closely related to the input sentence.

Input:
{query}

Output (no explanation, just titles or keywords):"""

    try:
        response = hf_client.text_generation(
            prompt,
            model=LLAMA_MODEL_ID,
            max_new_tokens=96,
            temperature=0.2,
            repetition_penalty=1.05,
        )
    except Exception as err:
        # Best-effort: any inference failure degrades to the raw query
        # so the search UI keeps working.
        print("LLaMA expansion failed on HF, using raw query:", repr(err))
        return query

    # Flatten the model output onto one line and append it to the query.
    flattened = str(response).strip().replace("\n", " ")
    return query + " " + flattened
102
+
103
 
104
  def distances_to_similarity_pct(dists):
105
  if len(dists) == 0: return np.array([])