File size: 9,686 Bytes
2dce42d
 
f7c3f67
7bcd422
3c5a859
 
5ab3495
7bcd422
84c55de
075046a
a5299e8
 
84c55de
7bcd422
 
2dce42d
fe2a201
 
 
 
0790f90
 
b3a1fb5
0790f90
b3a1fb5
075046a
b3a1fb5
0790f90
a256478
14a152e
0790f90
c6204c3
 
 
 
 
 
 
fe2a201
 
 
 
 
3d4a5a4
fe2a201
3d4a5a4
 
 
 
fe2a201
3d4a5a4
 
 
8ca8515
3d4a5a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9d21f5
3d4a5a4
 
 
 
075046a
 
 
ee86cc5
7247f09
075046a
 
ee86cc5
075046a
 
ee86cc5
fe45956
 
c460746
 
 
 
075046a
d7ed08f
e8eaf4f
075046a
 
0cf4f3e
98135d4
ee86cc5
 
0cf4f3e
53080de
f1253c3
6e3f9a7
 
ee86cc5
6e3f9a7
 
 
ee86cc5
6e3f9a7
ee86cc5
6e3f9a7
ee86cc5
3a39aa9
 
ee86cc5
 
c92242a
ee86cc5
0a2cd85
ee86cc5
6e3f9a7
 
72ec3af
6e3f9a7
 
 
 
 
3a39aa9
 
6e3f9a7
 
 
 
0a2cd85
6e3f9a7
 
f817407
72ec3af
6e3f9a7
 
 
 
9f3d759
ee86cc5
c460746
 
 
 
 
 
 
 
 
 
 
 
e165667
72ec3af
a82a292
3d4a5a4
c460746
 
 
 
 
 
3b7cc1c
c460746
 
3b7cc1c
 
84c55de
c460746
3d4a5a4
c460746
 
 
075046a
2b5db0a
 
 
 
 
 
 
 
84c55de
fafa9ed
84c55de
aea35bf
1d30282
b049feb
 
e165667
0a86385
 
 
 
 
 
ad01535
0a86385
18cf024
ad01535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18cf024
e165667
 
 
0a86385
e165667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad01535
18cf024
75bdde0
0a86385
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import gradio as gr
from huggingface_hub import InferenceClient
#import base64
import os
#from google import genai
#from google.genai import types
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
import numpy as np
import random
from PIL import Image
import io


# Populate os.environ from a local .env file (presumably optional) so that
# HF_API_KEY, read from os.environ further down, can be supplied that way.
load_dotenv()

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

def return_image(artist):
    """Return the filename of a thumbnail image for *artist*.

    Radiohead, Kendrick Lamar and Grateful Dead each have several pictures,
    and one is picked at random on every call. "Google Gemma" always maps to
    "gemma.png". Any other value falls back to "radiohead.png".
    """
    if artist == "Google Gemma":
        return "gemma.png"

    image_pools = (
        ("Radiohead", ["radiohead.png", "radiohead2.png"]),
        # NOTE(review): "kendrick3.png" is listed twice, which doubles its odds
        # of being picked; confirm that is intended.
        ("Kendrick Lamar", ["kendrick3.png", "kendrick3.png", "kendrick2.png", "kendrick4.png"]),
        ("Grateful Dead", ["bob.png", "bob2.png", "jerry.png"]),
    )
    for name, pool in image_pools:
        if artist == name:
            return random.choice(pool)

    return "radiohead.png"
    
#def find_most_relevant_lyric(lyrics, user_input):
#    user_doc = nlp(user_input)
#    best_match = max(lyrics, key=lambda lyric: user_doc.similarity(nlp(lyric)))
#    return best_match
#
#def stitch_lyrics(lyrics, line_number=1):
#    return [lyrics[i] + " " + lyrics[i + line_number] for i in range(len(lyrics) - line_number)]

# Load lyrics from a text file
def load_lyrics(filename):
    """Read a UTF-8 lyrics text file and return its lines.

    Args:
        filename: Path to the lyrics file.

    Returns:
        A list of the file's lines. Each line keeps its trailing newline;
        ``songs_from_text`` strips it.
    """
    with open(filename, "r", encoding="utf-8") as file:
        # readlines() already returns a fresh list, so no copy is needed.
        return file.readlines()
def songs_from_text(lines):
    """Group raw lyric lines into songs made of stanzas.

    Lines are stripped of surrounding whitespace. A line equal to the
    ``=``-rule delimiter below starts a new song, and a blank line starts a
    new stanza. Empty stanzas and empty songs are never emitted.

    Args:
        lines: Iterable of text lines, e.g. the result of ``load_lyrics``.

    Returns:
        A list of songs. Each song is a list of stanzas, and each stanza is a
        list of stripped lyric lines.
    """
    song_delimiter = "=================================================="
    songs, song, stanza = [], [], []

    for raw_line in lines:
        text = raw_line.strip()
        at_song_break = text == song_delimiter

        if at_song_break or text == "":
            # Both kinds of break close the stanza in progress...
            if stanza:
                song.append(stanza)
                stanza = []
            # ...and a song break also closes the song in progress.
            if at_song_break and song:
                songs.append(song)
                song = []
            continue

        stanza.append(text)

    # Flush whatever is still open at end of input.
    if stanza:
        song.append(stanza)
    if song:
        songs.append(song)
    return songs


def generate_cumulative_phrases(songs):
    """Build candidate reply phrases from runs of consecutive stanza lines.

    For every stanza, each run of consecutive lines that lies within the
    stanza's first four lines is joined with " // " and collected. Phrases
    are emitted in stanza order; within a stanza they are grouped by starting
    line and then by length.

    Args:
        songs: Songs as returned by ``songs_from_text`` (lists of stanzas,
            each a list of strings).

    Returns:
        A flat list of phrase strings.
    """
    phrases = []
    stanzas = (stanza for song in songs for stanza in song)
    for stanza in stanzas:
        # NOTE(review): this cap is an absolute line index, not a window
        # length, so lines after the 4th of a stanza never appear in any
        # phrase. `min(len(stanza), start + 4)` looks like the intent, but
        # artist_response indexes this list positionally against the
        # precomputed *_embeddings.npy files (presumably built from this very
        # function; confirm). Changing it therefore means regenerating those
        # files.
        stop = min(len(stanza), 4)
        for start in range(stop):
            phrase = ""
            for line in stanza[start:stop]:
                phrase = phrase + " // " + line if phrase else line
                phrases.append(phrase)
    return phrases
def artist_response(gemma_response, artist):
    """Turn the model's reply into a reply "in the voice of" *artist*.

    For "Google Gemma" the model text is returned unchanged. For the other
    supported artists the text is encoded (int8 precision) with the shared
    sentence encoder. The lyric phrase whose precomputed embedding has the
    highest cosine similarity to it is returned instead.

    Args:
        gemma_response: Text produced by the chat model.
        artist: One of "Radiohead", "Kendrick Lamar", "Grateful Dead" or
            "Google Gemma".

    Returns:
        The model text (Gemma) or the best-matching lyric phrase.

    Raises:
        ValueError: If *artist* is not a supported name.
    """
    if artist == "Google Gemma":
        return gemma_response

    if artist == "Radiohead":
        artist_embeddings, lyric_list = radiohead_embeddings, all_phrases_radiohead
    elif artist == "Kendrick Lamar":
        artist_embeddings, lyric_list = kendrick_embeddings, all_phrases_kendrick
    elif artist == "Grateful Dead":
        artist_embeddings, lyric_list = grateful_dead_embeddings, all_phrases_grateful_dead
    else:
        # Reject unknown names explicitly rather than failing later on an
        # unbound local variable.
        raise ValueError(f"Unknown artist: {artist!r}")

    encoded_reply = get_encoder().encode(gemma_response, precision="int8")
    similarities = cosine_similarity_int8(encoded_reply, artist_embeddings)
    # Row i of the embedding matrix must correspond to lyric_list[i].
    return lyric_list[int(np.argmax(similarities))]

def _history_to_messages(history):
    """Convert Gradio chat history into HF chat-completion message dicts.

    Accepts both Gradio history formats: ``(user, bot)`` pairs and
    ``{"role": ..., "content": ...}`` dicts (the format used by a
    ``gr.Chatbot(type='messages')`` component). Only the last 10 messages,
    i.e. five user/assistant exchanges, are kept to bound the prompt size.
    """
    messages = []
    for entry in history:
        if isinstance(entry, dict):
            messages.append({"role": entry["role"], "content": entry["content"]})
        else:
            user_msg, bot_msg = entry
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": bot_msg})
    return messages[-10:]


def _query_model(messages):
    """Send *messages* to the chat model and return the reply text.

    Any failure is returned as an ``"Error: ..."`` string instead of being
    raised, so the UI keeps responding when the endpoint is unavailable.
    """
    try:
        response = client.chat.completions.create(
            model="zai-org/GLM-4.7-Flash",
            messages=messages,
            max_tokens=256,
            temperature=0.75,
        )
        # The message content may be None; downstream encoding needs a string.
        return response.choices[0].message.content or ""
    except Exception as e:  # UI boundary: report the failure as text
        return f"Error: {str(e)}"


def chat_with_musician(user_input, history, artist):
    """Produce one chat reply for *user_input* in the style of *artist*.

    The model is prompted with the recent conversation plus the current
    message (prefixed with ``system_message``), and its reply is mapped
    through ``artist_response``. If the mapped reply equals the previous
    assistant reply, the model is asked once more with a "don't repeat"
    prompt.

    Args:
        user_input: The user's new message.
        history: Gradio chat history, either ``(user, bot)`` tuples or
            role/content dicts, or None. Mutated in place: it is cleared when
            *artist* differs from the last recorded artist, and the new
            exchange is appended in the format the history already uses.
        artist: Selected artist name (see ``artist_response``).

    Returns:
        The reply string shown to the user.
    """
    global artist_history
    if history is None:
        history = []
    if artist != artist_history[-1]:
        # Switching artists starts a fresh conversation.
        history.clear()

    history_messages = _history_to_messages(history)
    previous_reply = next(
        (m["content"] for m in reversed(history_messages) if m["role"] == "assistant"),
        None,
    )

    messages = history_messages + [
        {"role": "user", "content": system_message + "\n\n" + user_input}
    ]
    # NOTE(review): an "Error: ..." string from _query_model is also mapped
    # through artist_response, so non-Gemma artists answer an API failure
    # with a lyric; confirm that is desired.
    lyric_response = artist_response(_query_model(messages), artist)

    if previous_reply is not None and lyric_response == previous_reply:
        # system_message_repeated is not defined in this file; look it up
        # defensively so this retry path cannot raise NameError.
        repeated_prompt = globals().get(
            "system_message_repeated",
            system_message + " Do not repeat your previous reply.",
        )
        messages[-1] = {"role": "user", "content": repeated_prompt + "\n\n" + user_input}
        lyric_response = artist_response(_query_model(messages), artist)

    # Record the exchange in the same format the history already uses.
    if history and isinstance(history[-1], dict):
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": lyric_response})
    else:
        history.append((user_input, lyric_response))

    artist_history.append(artist)
    artist_history[:] = artist_history[-10:]
    return lyric_response

def cosine_similarity_int8(query, embeddings):
    """Cosine similarity between one quantized query and each embedding row.

    Args:
        query: Array of shape ``(d,)``.
        embeddings: Array of shape ``(n, d)``.

    Returns:
        Float array of shape ``(n,)``. A tiny epsilon in the denominator keeps
        all-zero vectors from causing a division by zero.
    """
    # Widen before multiplying; int8 products and sums would wrap around.
    query_wide = query.astype(np.int32)
    rows = embeddings.astype(np.int32)

    denominators = np.linalg.norm(rows, axis=1) * np.linalg.norm(query_wide) + 1e-8
    return (rows @ query_wide) / denominators

# Hugging Face token used to authenticate the InferenceClient. load_dotenv()
# above allows it to come from a .env file as well as the real environment.
try:
    HF_API_KEY = os.environ["HF_API_KEY"]
except KeyError as err:
    raise RuntimeError(
        "HF_API_KEY is not set; export it or add it to a .env file"
    ) from err



# Lazily-created SentenceTransformer shared by all callers of get_encoder().
_encoder_model = None

def get_encoder():
    """Return the shared sentence encoder, loading it on first use.

    The model ('all-MiniLM-L6-v2') is created once and cached in the
    module-level ``_encoder_model``, so later calls return the same object.
    """
    global _encoder_model
    if _encoder_model is not None:
        return _encoder_model
    # Speed-ups tried earlier and currently disabled: an OpenVINO backend with
    # a qint8-quantized model file, and a DOT_PRODUCT similarity function.
    _encoder_model = SentenceTransformer('all-MiniLM-L6-v2')
    return _encoder_model

# Precomputed, memory-mapped (read-only) embeddings of each artist's lyric
# phrases. NOTE(review): row i is assumed to match the i-th entry of the
# corresponding all_phrases_* list below, because artist_response indexes
# both with the same argmax.
radiohead_embeddings = np.load("radiohead_embeddings.npy", mmap_mode="r")
kendrick_embeddings = np.load("kendrick_embeddings.npy", mmap_mode="r")
grateful_dead_embeddings = np.load("grateful_dead_embeddings.npy", mmap_mode="r")

radiohead_lyrics = load_lyrics("radiohead_lyrics.txt")
kendrick_lyrics = load_lyrics("kendrick_lamar_lyrics.txt")
grateful_dead_lyrics = load_lyrics('grateful_dead_lyrics.txt')

# Candidate reply phrases per artist, derived from the raw lyric files.
all_phrases_radiohead = generate_cumulative_phrases(songs_from_text(radiohead_lyrics))
all_phrases_kendrick = generate_cumulative_phrases(songs_from_text(kendrick_lyrics))
all_phrases_grateful_dead = generate_cumulative_phrases(songs_from_text(grateful_dead_lyrics))

# Hugging Face Inference Client; the model is chosen per request.
client = InferenceClient(token=HF_API_KEY)

# Prepended to the current user message on every model call.
system_message = "Don't be too repetitive. Please limit your response to only a few sentences."

# Artists used on previous turns; the last entry is compared against the
# current selection to detect an artist switch.
artist_history = [""]

# Thumbnail size, passed as both the image height and width. (A duplicate
# earlier assignment that was always overwritten has been removed.)
size = 350


def respond(message, artist, chat_history):
    """Echo *message* back as a placeholder reply (not wired into the UI).

    Returns *chat_history* untouched when *message* is empty. Otherwise it
    appends ``(message, "Echo (<artist>): <message>")``, to a new list if the
    history is empty or None, and returns that list.
    """
    if message:
        history = chat_history or []
        history.append((message, f"Echo ({artist}): {message}"))
        return history
    return chat_history

def chatbot_response(message, artist, chat_history):
    """Gradio submit handler: add one user/bot exchange to the chat history.

    Blank or missing messages leave the history unchanged. When the selected
    artist differs from the one used on the previous turn, the conversation
    restarts from an empty history.

    Args:
        message: Text typed by the user.
        artist: Currently selected artist.
        chat_history: Current Chatbot value, either role/content dicts or
            ``(user, bot)`` tuples, or None.

    Returns:
        The updated history. New entries are role/content dicts (what a
        ``type='messages'`` Chatbot expects) unless the existing history is in
        tuple format.
    """
    if message is None or message.strip() == "":
        return chat_history or []

    # Detect an artist switch before chat_with_musician records the new artist.
    switched_artist = artist != artist_history[-1]
    history = [] if switched_artist else list(chat_history or [])

    # Pass a copy: chat_with_musician records the exchange in the list it is
    # given, and recording it here as well would show every turn twice.
    response = chat_with_musician(message, list(history), artist)

    if history and not isinstance(history[-1], dict):
        history.append((message, response))
    else:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
    return history

def update_artist_image(artist):
    """Dropdown change handler: return the thumbnail filename for *artist*."""
    image_path = return_image(artist)
    return image_path

# UI layout: artist picker and thumbnail on the left, chat on the right.
# Component order inside the context managers determines on-screen order.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # These names must match those handled by return_image and
            # artist_response.
            artist_dropdown = gr.Dropdown(
                choices=["Radiohead", "Kendrick Lamar", "Grateful Dead", "Google Gemma"],
                value="Radiohead",
                label="Select artist",
                interactive=True,
            )
            artist_image = gr.Image(
                # Initial picture matches the dropdown's default value.
                value=return_image("Radiohead"),
                label="Thumbnail",
                height=size,
                width=size,
                show_label=False,
                show_fullscreen_button=False,
                show_download_button=False,
                show_share_button=False,
            )
        with gr.Column(scale=1):
            # type='messages': per the Gradio docs, the history value is a list
            # of {"role": ..., "content": ...} dicts, so handlers that return
            # a history should use that format.
            chatbot = gr.Chatbot(height=400, type='messages')
            message_input = gr.Textbox(
                label="Your message",
                placeholder="Enter a message and press Enter",
                lines=2,
                interactive=True,
            )

    # Swap the thumbnail whenever a different artist is selected.
    artist_dropdown.change(fn=update_artist_image, inputs=artist_dropdown, outputs=artist_image)
    # On Enter: run chatbot_response and display the history it returns, then
    # clear the textbox.
    message_input.submit(fn=chatbot_response, inputs=[message_input, artist_dropdown, chatbot], outputs=chatbot).then(lambda: "", None, message_input)

if __name__ == "__main__":
    demo.launch()