# Hugging Face Space (status header from the Spaces page; kept as a comment)
# Spaces: Running
import io
import os
import random

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from PIL import Image
from sentence_transformers import SentenceTransformer

# import base64
# from google import genai
# from google.genai import types

load_dotenv()
| """ | |
| For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference | |
| """ | |
# Candidate thumbnail files per artist; one is picked at random on each call.
# NOTE(review): the original listed 'kendrick3.png' twice (presumed typo),
# which silently doubled its selection probability; deduplicated here.
_ARTIST_IMAGES = {
    "Radiohead": ["radiohead.png", "radiohead2.png"],
    "Kendrick Lamar": ["kendrick3.png", "kendrick2.png", "kendrick4.png"],
    "Grateful Dead": ["bob.png", "bob2.png", "jerry.png"],
    "Google Gemma": ["gemma.png"],
}


def return_image(artist):
    """Return a thumbnail image filename for *artist*.

    Unknown artists fall back to the default Radiohead image, matching the
    original final ``return "radiohead.png"``.
    """
    return random.choice(_ARTIST_IMAGES.get(artist, ["radiohead.png"]))
# def find_most_relevant_lyric(lyrics, user_input):
#     user_doc = nlp(user_input)
#     best_match = max(lyrics, key=lambda lyric: user_doc.similarity(nlp(lyric)))
#     return best_match
#
# def stitch_lyrics(lyrics, line_number=1):
#     return [lyrics[i] + " " + lyrics[i + line_number] for i in range(len(lyrics) - line_number)]

# Load lyrics from a text file
def load_lyrics(filename):
    """Read a lyrics text file and return its lines, trailing newlines kept.

    The original rebuilt ``readlines()`` via a redundant identity list
    comprehension (and carried a duplicated commented-out copy of that
    line); returning the list directly is equivalent.
    """
    with open(filename, "r", encoding="utf-8") as file:
        return file.readlines()
def songs_from_text(lines):
    """Parse raw lyric lines into songs.

    Each song is a list of stanzas; each stanza is a list of stripped,
    non-empty lines. A run of '=' characters separates songs; a blank line
    separates stanzas.
    """
    delimiter = "=================================================="
    songs = []
    song = []
    stanza = []

    def close_stanza():
        # Flush the in-progress stanza (if any) into the current song.
        nonlocal stanza
        if stanza:
            song.append(stanza)
            stanza = []

    for raw in lines:
        text = raw.strip()
        if text == delimiter:
            # Song boundary: close the stanza, then the song.
            close_stanza()
            if song:
                songs.append(song)
                song = []
        elif not text:
            # Blank line: stanza boundary.
            close_stanza()
        else:
            stanza.append(text)

    # Flush whatever remains after the last line.
    close_stanza()
    if song:
        songs.append(song)
    return songs
def generate_cumulative_phrases(songs):
    """Build ' // '-joined phrase strings from every stanza start position.

    For start index ``i`` this joins stanza lines ``i`` through
    ``min(len(stanza), 4) - 1``.

    NOTE(review): the inner bound is ``min(len(stanza), 4)``, not ``i + 4``,
    so any start index >= 4 produces an empty string. This looks like an
    off-by-window bug, but it is preserved intentionally: the precomputed
    ``*_embeddings.npy`` files must stay index-aligned with this list.
    Confirm the intended windowing before changing it (and regenerate the
    embeddings if you do).
    """
    phrases = []
    for song in songs:
        for stanza in song:
            cutoff = min(len(stanza), 4)
            for start in range(len(stanza)):
                phrases.append(" // ".join(stanza[start:cutoff]))
    return phrases
def artist_response(gemma_response, artist):
    """Map the raw model reply onto the closest lyric phrase for *artist*.

    For "Google Gemma" the model reply passes through untouched. Otherwise
    the reply is embedded (int8 precision) and matched by cosine similarity
    against the artist's precomputed phrase embeddings; the best-scoring
    phrase is returned.
    """
    if artist == "Google Gemma":
        return gemma_response

    if artist == "Radiohead":
        embeddings, phrases = radiohead_embeddings, all_phrases_radiohead
    elif artist == "Kendrick Lamar":
        embeddings, phrases = kendrick_embeddings, all_phrases_kendrick
    elif artist == "Grateful Dead":
        embeddings, phrases = grateful_dead_embeddings, all_phrases_grateful_dead
    # NOTE(review): an unrecognised artist leaves `embeddings` unbound and
    # raises NameError, exactly as the original did; the dropdown in the UI
    # constrains the value so this path should be unreachable.

    query_vec = get_encoder().encode(gemma_response, precision="int8")
    scores = cosine_similarity_int8(query_vec, embeddings)
    return phrases[int(np.argmax(scores))]
def chat_with_musician(user_input, history, artist):
    """Generate a lyric-flavoured reply for *user_input*.

    Queries the hosted chat model with up to the last 5 (user, bot)
    exchanges from *history* as context, then maps the model's reply onto
    the selected artist's lyrics via ``artist_response``. If the mapped
    lyric repeats the previous assistant turn, the model is queried once
    more with a stronger "don't repeat" instruction.

    Side effects: mutates *history* in place (cleared when the artist
    changes; the new exchange is appended) and appends to the module-level
    ``artist_history`` (capped at the last 10 entries).
    """
    global artist_history
    if history is None:
        history = []

    # Reset the conversation when the user switches artists mid-chat.
    if artist != artist_history[-1]:
        history.clear()

    # Convert Gradio (user, bot) tuples into HF chat-message dicts,
    # keeping only the last 5 exchanges for context.
    messages = []
    for user_msg, bot_msg in history[-5:]:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": system_message + "\n\n" + user_input})

    gemma_response = _query_chat_model(messages)
    lyric_response = artist_response(gemma_response, artist)

    # Retry once if the mapped lyric repeats the previous assistant turn.
    if len(messages) > 1 and lyric_response == messages[-2]["content"]:
        # BUG FIX(review): the original referenced `system_message_repeated`,
        # which is never defined in this file, so this branch raised
        # NameError at runtime. Fall back to a strengthened prompt when the
        # module-level constant is absent.
        repeat_prompt = globals().get(
            "system_message_repeated",
            system_message + " Do not repeat your previous answer.",
        )
        messages[-1] = {"role": "user", "content": repeat_prompt + "\n\n" + user_input}
        gemma_response = _query_chat_model(messages)
        lyric_response = artist_response(gemma_response, artist)

    history.append((user_input, lyric_response))
    artist_history.append(artist)
    artist_history[:] = artist_history[-10:]  # cap tracked selections
    return lyric_response


def _query_chat_model(messages):
    """Call the hosted chat model once; return reply text or an error string.

    The original duplicated this request block verbatim for the retry path;
    extracted so both calls share one implementation. Failures are returned
    as a user-visible "Error: ..." string rather than raised, matching the
    original behavior.
    """
    try:
        response = client.chat.completions.create(
            model="zai-org/GLM-4.7-Flash",
            messages=messages,
            max_tokens=256,
            temperature=0.75,
        )
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"
def cosine_similarity_int8(query, embeddings):
    """Cosine similarity between one int8 vector and a matrix of them.

    Args:
        query: shape (d,) int8 query vector.
        embeddings: shape (n, d) int8 embedding matrix.

    Returns:
        Shape (n,) float array of cosine similarities; a tiny epsilon in
        the denominator guards against zero-norm rows.
    """
    # Widen to int32 before the dot product so int8 accumulation can't overflow.
    q = np.asarray(query).astype(np.int32)
    m = np.asarray(embeddings).astype(np.int32)
    numerator = m @ q
    denominator = np.linalg.norm(m, axis=1) * np.linalg.norm(q) + 1e-8
    return numerator / denominator
# Hugging Face Inference API token; raises KeyError at import time if unset.
HF_API_KEY = os.environ["HF_API_KEY"]

# Lazily-constructed sentence-embedding model (module-level singleton).
_encoder_model = None

def get_encoder():
    """Return the shared SentenceTransformer encoder, creating it on first use.

    Lazy construction keeps app start-up fast and avoids loading the model
    when it is never needed (e.g. "Google Gemma" pass-through mode). The
    commented-out kwargs are leftover experiments with an OpenVINO int8
    backend and an alternative similarity function.
    """
    global _encoder_model
    if _encoder_model is None:
        _encoder_model = SentenceTransformer('all-MiniLM-L6-v2',
            #'sentence-transformers/all-MiniLM-L6-v2',
            #backend='openvino',
            #model_kwargs={"file_name": "openvino/openvino_model_qint8_quantized.xml"},
            #to increase speed:
            #similarity_function=SimilarityFunction.DOT_PRODUCT,
            )
    return _encoder_model
# Precomputed int8 lyric-phrase embeddings; mmap keeps them on disk until
# rows are actually read. Row order MUST stay index-aligned with the
# corresponding all_phrases_* list below (argmax over similarity scores is
# used as a direct index into the phrase list).
radiohead_embeddings = np.load("radiohead_embeddings.npy", mmap_mode="r")
kendrick_embeddings = np.load("kendrick_embeddings.npy", mmap_mode="r")
grateful_dead_embeddings = np.load("grateful_dead_embeddings.npy", mmap_mode="r")

# Raw lyric files -> per-artist cumulative phrase lists.
radiohead_lyrics = load_lyrics("radiohead_lyrics.txt")
kendrick_lyrics = load_lyrics("kendrick_lamar_lyrics.txt")
grateful_dead_lyrics = load_lyrics('grateful_dead_lyrics.txt')
all_phrases_radiohead = generate_cumulative_phrases(songs_from_text(radiohead_lyrics))
all_phrases_kendrick = generate_cumulative_phrases(songs_from_text(kendrick_lyrics))
all_phrases_grateful_dead = generate_cumulative_phrases(songs_from_text(grateful_dead_lyrics))

# Hugging Face Inference client (provider-default model; a specific model
# can be pinned via the commented-out kwarg).
client = InferenceClient(token=HF_API_KEY,
                         #model="MiniMaxAI/MiniMax-M2.1",)
                         )

# Prompt prefixes prepended to every user message sent to the chat model.
system_message = "Don't be too repetitive. Please limit your response to only a few sentences."
# BUG FIX(review): chat_with_musician references `system_message_repeated`
# on its retry path, but it was never defined anywhere in this file and
# would raise NameError at runtime; defined here.
system_message_repeated = (
    "Don't be too repetitive. Please limit your response to only a few "
    "sentences. Do not repeat your previous answer."
)

# Tracks previous artist selections (last entry = most recent).
artist_history = [""]

# Side length (px) of the artist thumbnail.
# BUG FIX(review): `size` was assigned twice (350 both times); the
# duplicate assignment was removed.
size = 350
def respond(message, artist, chat_history):
    """Demo echo handler: append an 'Echo (artist): message' reply.

    A falsy *message* returns *chat_history* unchanged. When *chat_history*
    is a non-empty list it is mutated in place; otherwise a fresh list is
    returned.
    """
    if not message:
        return chat_history
    echoed = f"Echo ({artist}): {message}"
    history = chat_history if chat_history else []
    history.append((message, echoed))
    return history
def chatbot_response(message, artist, chat_history):
    """Gradio submit handler: run the lyric chatbot, return updated history.

    Blank or None messages leave the history unchanged. The history list is
    mutated in place by ``chat_with_musician``, which clears it when the
    artist changes and appends the new (user, bot) exchange itself.
    """
    if message is None or message.strip() == "":
        return chat_history or []
    history = chat_history or []
    # BUG FIX(review): chat_with_musician already appends the new exchange
    # to the history it is given; the original appended it a *second* time
    # here, duplicating every exchange once the history was non-empty. The
    # unused `global artist_history` declaration was also removed.
    chat_with_musician(message, history, artist)
    return history
def update_artist_image(artist):
    """Dropdown-change handler: resolve *artist* to a thumbnail image path."""
    selected = return_image(artist)
    return selected
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Artist picker; changing it swaps the thumbnail and (inside the
            # chat handler) clears the conversation context.
            artist_dropdown = gr.Dropdown(
                choices=["Radiohead", "Kendrick Lamar", "Grateful Dead", "Google Gemma"],
                value="Radiohead",
                label="Select artist",
                interactive=True,
            )
            artist_image = gr.Image(
                value=return_image("Radiohead"),
                label="Thumbnail",
                height=size,
                width=size,
                show_label=False,
                show_fullscreen_button=False,
                show_download_button=False,
                show_share_button=False,
            )
        with gr.Column(scale=1):
            # BUG FIX(review): the chat handlers build history as
            # (user, bot) tuples, but the component was created with
            # type='messages', which expects role/content dicts and would
            # reject the tuple history. 'tuples' matches the data the
            # handlers actually produce. (Longer term, migrating the
            # handlers to the messages format is preferable — 'tuples' is
            # deprecated in recent Gradio releases.)
            chatbot = gr.Chatbot(height=400, type='tuples')
            message_input = gr.Textbox(
                label="Your message",
                placeholder="Enter a message and press Enter",
                lines=2,
                interactive=True,
            )

    artist_dropdown.change(fn=update_artist_image, inputs=artist_dropdown, outputs=artist_image)
    # Submit sends the message to the bot, then clears the textbox.
    message_input.submit(
        fn=chatbot_response,
        inputs=[message_input, artist_dropdown, chatbot],
        outputs=chatbot,
    ).then(lambda: "", None, message_input)

if __name__ == "__main__":
    demo.launch()