Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import re | |
| from groq import Groq | |
| from dotenv import load_dotenv | |
| import httpx | |
| from tools import tools | |
| from utils import execute_sql_query | |
# Load environment variables (python-dotenv reads them from a .env file).
load_dotenv()

# Initialize the Groq client with the API key taken from the environment.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"), http_client=httpx.Client())

# SECURITY: never print the key itself -- stdout ends up in shared logs.
# Only report whether a key was found, which is all the debug line was for.
print(f"GROQ_API_KEY configured: {bool(os.getenv('GROQ_API_KEY'))}")
# System prompt sent with every request: describes the table and the SQL
# conventions the model is asked to follow.
_SYSTEM_PROMPT = (
    "You are a helpful assistant that can query a Supabase PostgreSQL database using SQL. "
    "Use the execute_sql_query function only when the user explicitly asks for data from the database (e.g., 'give me songs', 'find songs', 'top songs', 'give top songs'). "
    "For greetings like 'hi' or 'hello', respond with a simple greeting like 'Hello! How can I help you?' without querying the database. "
    "The database has a 'songs' table with columns: \"Track Name\", \"Artist Name(s)\", \"Valence\", \"Popularity\", etc. "
    "Always use quoted column names to handle case sensitivity and special characters (e.g., \"Track Name\" with quotes). "
    "Ensure there is a space after each quoted column name in the SELECT clause and a space before the FROM keyword (e.g., SELECT \"Track Name\", \"Artist Name(s)\" FROM with spaces). "
    "The \"Artist Name(s)\" column may contain multiple artists as a comma-separated string, so use ILIKE for partial matching (e.g., \"Artist Name(s)\" ILIKE '%artist_name%'). "
    "For queries like 'top X songs' or 'give top X songs', extract the number X (default to 10 if not specified) and use it in the LIMIT clause. "
    "For generic song queries, use: SELECT \"Track Name\", \"Artist Name(s)\" FROM songs LIMIT X. "
    "If the user specifies a sorting criterion (e.g., 'top 10 songs by popularity'), sort by the appropriate column (e.g., ORDER BY \"Popularity\" DESC). "
    "Always return SELECT \"Track Name\", \"Artist Name(s)\" in the query, not SELECT *. "
    # The backslashes are doubled on purpose: a single "\u003c" in Python
    # source is just "<", which made the original instruction self-contradictory.
    "Generate complete and valid JSON and SQL queries, ensuring proper escaping of quotes, correct spacing, and using ASCII characters for operators (e.g., use < and >, not \\u003c or \\u003e)."
)


class _ParsedFunction:
    """Holds the ``name`` and ``arguments`` of a tool call parsed from text."""

    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments


class _ParsedToolCall:
    """Attribute-style wrapper around a tool-call dict parsed from message text.

    Exposes ``.function.name`` and ``.function.arguments`` so the rest of the
    code can treat it like the tool-call objects found on the API response.
    """

    def __init__(self, data):
        # Copy the keys instead of aliasing the dict so the input is not mutated.
        self.__dict__.update(data)
        function = data["function"]
        arguments = function["arguments"]
        # Downstream code expects a JSON string; re-serialize anything else.
        if not isinstance(arguments, str):
            arguments = json.dumps(arguments)
        self.function = _ParsedFunction(function["name"], arguments)


def _parse_tool_use_block(message_content):
    """Extract tool calls from a ``<tool-use>`` block embedded in the text.

    Returns:
        None when the text has no ``<tool-use>`` block, an empty list when the
        block is not valid JSON, otherwise a list of ``_ParsedToolCall``.
    """
    block = re.search(
        r'<tool-use>\n(.*)\n</tool-use>', message_content, re.DOTALL)
    if block is None:
        return None
    try:
        payload = json.loads(block.group(1))
    except json.JSONDecodeError as e:
        print(f"Failed to parse /tool-use block: {e}")
        return []
    return [_ParsedToolCall(item) for item in payload.get("tool_calls", [])]


def _extract_sql_query(arguments):
    """Pull the ``sql_query`` value out of a tool call's arguments.

    ``arguments`` is normally a JSON-encoded object string. It is decoded with
    ``json.loads`` so every JSON escape is handled; a tolerant regex is used
    only as a fallback for malformed JSON.

    Returns:
        The query without surrounding whitespace or trailing semicolons, or an
        empty string when no query could be extracted.
    """
    parsed = None
    if isinstance(arguments, dict):
        parsed = arguments
    elif isinstance(arguments, str):
        try:
            parsed = json.loads(arguments)
        except json.JSONDecodeError:
            parsed = None
    else:
        return ""

    if isinstance(parsed, dict) and isinstance(parsed.get("sql_query"), str):
        sql_query = parsed["sql_query"]
    elif isinstance(arguments, str):
        # Fallback for invalid JSON: replace the literal \u003e / \u003c escape
        # text, then grab the quoted value (whitespace around ':' is allowed).
        normalized = arguments.replace(
            '\\u003e', '>').replace('\\u003c', '<')
        found = re.search(
            r'"sql_query"\s*:\s*"((?:[^"\\]|\\.)*)"', normalized)
        if found is None:
            print("Failed to extract sql_query from arguments_str")
            return ""
        sql_query = found.group(1).replace('\\"', '"')
    else:
        print("Failed to extract sql_query from arguments_str")
        return ""

    return sql_query.strip().rstrip(';').rstrip()


def _clean_sql_query(sql_query):
    """Apply the textual fix-ups for common formatting slips in generated SQL."""
    # Remove any escaped quotes that survived extraction.
    sql_query = sql_query.replace('\\"', '"')
    # Replace literal \u003e / \u003c escape text with the real operators.
    sql_query = sql_query.replace('\\u003e', '>').replace('\\u003c', '<')
    # Collapse a doubled '$' anchor in this specific numeric regex pattern.
    sql_query = sql_query.replace('^[0-9.]+$$', '^[0-9.]+$')
    # Normalize spacing around the comma between two quoted column names.
    sql_query = re.sub(
        r'("[^"]+")\s*,\s*("[^"]+")', r'\1, \2', sql_query)
    # Insert the missing space when FROM is glued to a quoted column name.
    # The flag is passed via `flags=`: an inline (?i) in the middle of a
    # pattern raises re.error on Python 3.11+.
    sql_query = re.sub(
        r'("[^"]+")FROM', r'\1 FROM', sql_query, flags=re.IGNORECASE)
    return sql_query


def _format_song_rows(rows, limit):
    """Render at most ``limit`` rows as a numbered "track by artist" list."""
    shown = rows[:limit]
    lines = [f"Top {len(shown)} Songs:"]
    for position, row in enumerate(shown, 1):
        track_name = row.get("Track Name", "Unknown Track")
        artist_names = row.get("Artist Name(s)", "Unknown Artist")
        lines.append(f"{position}. {track_name} by {artist_names}")
    return "\n".join(lines).strip()


def chat_with_groq(user_input: str) -> str:
    """
    Processes user input using Groq API and executes SQL queries on Supabase when needed.

    The model may answer in plain text or request the ``execute_sql_query``
    tool. For a tool request the SQL is extracted, cleaned, checked to start
    with SELECT, executed, and the rows are formatted as a numbered list.

    Args:
        user_input (str): The user's query.

    Returns:
        str: Response from the chatbot. Failures are reported as a message
        string; this function does not raise.
    """
    try:
        # Extract the number from "top X songs" if present; default to 10.
        limit = 10
        limit_match = re.search(r'\btop\s+(\d+)', user_input.lower())
        if limit_match:
            limit = int(limit_match.group(1))

        response = groq_client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_input},
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=4096,
        )
        print(f"Full response: {response}")  # Debug the entire response

        message = response.choices[0].message
        tool_calls = getattr(message, 'tool_calls', None)
        message_content = getattr(message, 'content', None)

        # Some replies carry the tool call as a <tool-use> block in the text
        # instead of in the structured tool_calls field.
        if not tool_calls and message_content:
            embedded_calls = _parse_tool_use_block(message_content)
            if embedded_calls is not None:
                tool_calls = embedded_calls

        for tool_call in tool_calls or []:
            if tool_call.function.name != "execute_sql_query":
                continue

            arguments_str = tool_call.function.arguments
            print(f"Raw arguments_str: {arguments_str}")

            sql_query = _extract_sql_query(arguments_str)
            if not sql_query:
                return "⚠️ No SQL query provided."
            print(f"Extracted SQL query: {sql_query}")

            sql_query = _clean_sql_query(sql_query)
            print(f"Cleaned SQL query: {sql_query}")

            # Basic guard: only statements that start with SELECT are run.
            # NOTE(review): this does not reject stacked statements such as
            # "SELECT ...; DROP ..." -- confirm execute_sql_query runs with
            # read-only privileges.
            if not sql_query.strip().upper().startswith("SELECT"):
                return f"⚠️ Invalid SQL query: {sql_query}"

            print(f"Executing SQL Query: {sql_query}")
            result = execute_sql_query(sql_query)

            if not isinstance(result, list):
                return result  # Error message from execute_sql_query
            if not result:
                return "🔍 No results found for the query."
            return _format_song_rows(result, limit)

        # Fallback for no tool calls (e.g., greetings)
        if message_content and not tool_calls:
            # A <tool-use> block with an empty tool_calls list is answered
            # with a plain greeting instead of echoing the raw block.
            if '<tool-use>' in message_content and '"tool_calls": []' in message_content:
                return "Hello! How can I help you?"
            return message_content.strip()
        return "I'm sorry, I couldn't process your request. (No message content or tool calls found)"
    except Exception as e:
        # Top-level boundary: report every failure to the caller as text.
        print(e)
        return f"Error: {str(e)}"