Spaces:
Sleeping
Sleeping
Rob Learsch committed on
Commit ·
fe2a201
1
Parent(s): 687d4b6
Update app.py
Browse files
Back to working code, no thanks to Claude
app.py
CHANGED
|
@@ -1,72 +1,127 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from huggingface_hub import InferenceClient
|
|
|
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from dotenv import load_dotenv
|
| 5 |
|
| 6 |
load_dotenv()
|
| 7 |
HF_API_KEY = os.environ["HF_API_KEY"]
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# Initialize Hugging Face Inference Client
|
| 10 |
client = InferenceClient(model="google/gemma-2-2b-it", token=HF_API_KEY)
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
# Return immediately to update the UI with the user's message
|
| 22 |
-
return "", history + [[message, None]]
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# Make a simple API call with just the current message
|
| 30 |
-
response = client.chat_completion(
|
| 31 |
-
messages=[{"role": "user", "content": user_message}],
|
| 32 |
-
model="google/gemma-2-2b-it",
|
| 33 |
-
max_tokens=256,
|
| 34 |
-
temperature=0.7
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
# Print the response structure to help diagnose issues
|
| 38 |
-
print(f"API response: {response}")
|
| 39 |
-
|
| 40 |
-
# Try to extract the bot message - handle different response formats
|
| 41 |
-
bot_message = None
|
| 42 |
-
if isinstance(response, dict):
|
| 43 |
-
if "choices" in response and len(response["choices"]) > 0:
|
| 44 |
-
choice = response["choices"][0]
|
| 45 |
-
if isinstance(choice, dict) and "message" in choice:
|
| 46 |
-
message = choice["message"]
|
| 47 |
-
if isinstance(message, dict) and "content" in message:
|
| 48 |
-
bot_message = message["content"]
|
| 49 |
-
|
| 50 |
-
# If we couldn't extract a message, use a fallback
|
| 51 |
-
if bot_message is None:
|
| 52 |
-
bot_message = "Sorry, I couldn't generate a proper response."
|
| 53 |
-
|
| 54 |
-
# Update the last history item with the bot's response
|
| 55 |
-
history[-1][1] = bot_message
|
| 56 |
-
|
| 57 |
-
except Exception as e:
|
| 58 |
-
print(f"Error: {str(e)}")
|
| 59 |
-
history[-1][1] = f"Error: {str(e)}"
|
| 60 |
-
|
| 61 |
-
return history
|
| 62 |
|
| 63 |
-
#
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
)
|
|
|
|
| 67 |
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
-
# Launch the
|
| 71 |
if __name__ == "__main__":
|
| 72 |
-
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from huggingface_hub import InferenceClient
|
| 3 |
+
import base64
|
| 4 |
import os
|
| 5 |
+
from google import genai
|
| 6 |
+
from google.genai import types
|
| 7 |
+
import spacy
|
| 8 |
+
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
|
| 11 |
load_dotenv()
|
| 12 |
HF_API_KEY = os.environ["HF_API_KEY"]
|
| 13 |
|
| 14 |
+
"""
|
| 15 |
+
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
| 16 |
+
"""
|
| 17 |
+
nlp = spacy.load("en_core_web_md")
|
| 18 |
+
# !python -m spacy download en_core_web_md
|
| 19 |
+
|
| 20 |
+
def find_most_relevant_lyric(lyrics, user_input):
|
| 21 |
+
user_doc = nlp(user_input)
|
| 22 |
+
best_match = max(lyrics, key=lambda lyric: user_doc.similarity(nlp(lyric)))
|
| 23 |
+
return best_match
|
| 24 |
+
|
| 25 |
+
def stitch_lyrics(lyrics):
|
| 26 |
+
return [lyrics[i] + " " + lyrics[i + line_number] for line_number in range(1,5) for i in range(len(lyrics) - line_number)]
|
| 27 |
+
|
| 28 |
+
# Load lyrics from a text file
|
| 29 |
+
def load_lyrics(filename):
|
| 30 |
+
with open(filename, "r", encoding="utf-8") as file:
|
| 31 |
+
lyrics = file.readlines()
|
| 32 |
+
return [line.strip() for line in lyrics if line.strip()] # Remove empty lines
|
| 33 |
+
#return [line for line in lyrics]
|
| 34 |
+
|
| 35 |
+
lyrics = load_lyrics("the_bends_lyrics.txt")
|
| 36 |
+
stitched_lyrics = stitch_lyrics(lyrics)
|
| 37 |
+
|
| 38 |
# Initialize Hugging Face Inference Client
|
| 39 |
client = InferenceClient(model="google/gemma-2-2b-it", token=HF_API_KEY)
|
| 40 |
|
| 41 |
+
system_message = "Please limit responses to only a few sentences."
|
| 42 |
+
# Function to generate chatbot responses
|
| 43 |
+
def chat_with_musician(user_input, history, artist):
|
| 44 |
+
if history is None:
|
| 45 |
+
history = []
|
| 46 |
+
messages = []
|
| 47 |
+
#messages.append({"role": "user", "content": system_message})
|
| 48 |
|
| 49 |
+
#for pair in history[-5:]: # Keep only the last 5 exchanges
|
| 50 |
+
# if len(pair) == 2: # Only process valid (user, bot) pairs
|
| 51 |
+
# user_msg, bot_msg = pair
|
| 52 |
+
# messages.append({"role": "user", "content": user_msg})
|
| 53 |
+
# messages.append({"role": "assistant", "content": bot_msg})
|
| 54 |
+
#for pair in history[-5:]: # Keep only the last 5 exchanges
|
| 55 |
+
# if len(pair) == 2: # Only process valid (user, bot) pairs
|
| 56 |
+
# user_msg, bot_msg = pair
|
| 57 |
+
# messages.append({"role": "user", "content": user_msg})
|
| 58 |
+
# messages.append({"role": "assistant", "content": bot_msg})
|
| 59 |
+
# Add the latest user message
|
| 60 |
+
messages.append({"role": "user", "content": system_message + "\n\n" + user_input})
|
| 61 |
+
try:
|
| 62 |
+
response = client.chat_completion(
|
| 63 |
+
messages=messages,
|
| 64 |
+
model="google/gemma-2-2b-it",
|
| 65 |
+
max_tokens=256,
|
| 66 |
+
temperature=0.7,
|
| 67 |
+
top_p=0.9
|
| 68 |
+
)
|
| 69 |
+
gemma_response= response["choices"][0]["message"]["content"]
|
| 70 |
+
# history.append({"role": "user", "content": user_msg})
|
| 71 |
+
# history.append({"role": "assistant", "content": gemma_response})
|
| 72 |
+
except Exception as e:
|
| 73 |
+
return f"Error: {str(e)}"
|
| 74 |
+
if artist == "Radiohead":
|
| 75 |
+
lyric_response = find_most_relevant_lyric(stitched_lyrics,
|
| 76 |
+
gemma_response)
|
| 77 |
+
if artist == "Google Gemma":
|
| 78 |
+
lyric_response = gemma_response
|
| 79 |
+
return lyric_response, history
|
| 80 |
|
| 81 |
+
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
def chat_with_gemma(user_input, history):
|
| 84 |
+
if history is None:
|
| 85 |
+
history = []
|
| 86 |
+
messages = []
|
| 87 |
+
#messages.append({"role": "user", "content": system_message})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
+
#for pair in history[-5:]: # Keep only the last 5 exchanges
|
| 90 |
+
# if len(pair) == 2: # Only process valid (user, bot) pairs
|
| 91 |
+
# user_msg, bot_msg = pair
|
| 92 |
+
# messages.append({"role": "user", "content": user_msg})
|
| 93 |
+
# messages.append({"role": "assistant", "content": bot_msg})
|
| 94 |
|
| 95 |
+
# Add the latest user message
|
| 96 |
+
messages.append({"role": "user", "content": system_message + "\n\n" + user_input})
|
| 97 |
+
try:
|
| 98 |
+
response = client.chat_completion(
|
| 99 |
+
messages=messages,
|
| 100 |
+
model="google/gemma-2-2b-it",
|
| 101 |
+
max_tokens=512,
|
| 102 |
+
temperature=1,
|
| 103 |
+
top_p=0.75
|
| 104 |
+
)
|
| 105 |
+
bot_response = ["choices"][0]["message"]["content"]
|
| 106 |
+
history.append([user_input, bot_response])
|
| 107 |
+
return bot_response, history
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
return f"Error: {str(e)}"
|
| 111 |
+
|
| 112 |
+
# Gradio Chat Interface
|
| 113 |
+
demo = gr.ChatInterface(
|
| 114 |
+
fn = chat_with_musician,
|
| 115 |
+
type="messages",
|
| 116 |
+
additional_inputs=[
|
| 117 |
+
gr.Dropdown(choices=["Radiohead", "Google Gemma"],
|
| 118 |
+
value="Radiohead",
|
| 119 |
+
label="Select artist",
|
| 120 |
+
info="More coming soon"),
|
| 121 |
+
],
|
| 122 |
+
title="Lyrical Chatbot",
|
| 123 |
+
)
|
| 124 |
|
| 125 |
+
# Launch the chatbot
|
| 126 |
if __name__ == "__main__":
|
| 127 |
+
demo.launch()
|