Spaces:
Runtime error
Runtime error
import json
import os
import sys
from pathlib import Path
from typing import Optional, List, Tuple, Dict

import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from gradio_client import Client, handle_file
from sentence_transformers import SentenceTransformer
# --- 1. SETUP AND MODEL LOADING ---
# Pull settings from a .env file (if any) before reading the API key.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found. It's needed for summarizing analysis results.")

print("Loading models and connecting to clients...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
try:
    # Remote Gradio apps used for chat, audio and image analysis.
    chatbot_client = Client("Anvit25/LLM_chatbot2")
    audio_client = Client("Anvit25/new_audio")
    vision_client = Client("Anvit25/vision-classifier")
except Exception as e:
    print(f"FATAL: Failed to connect to a Gradio client: {e}")
    # A fatal start-up failure must yield a non-zero exit status; the bare
    # `exit()` helper comes from `site` and reports success (status 0).
    sys.exit(1)
else:
    print("All models and clients loaded successfully.")
# --- 2. LOAD & PRECOMPUTE DATA FOR LOCAL SEARCH & INTENT ---
# Load local image data: each entry of `image_list` pairs a file found under
# ./images with the description from image.json that matched its file name,
# plus a precomputed embedding of that description.
image_list = []
try:
    # Only the open/parse step is guarded, so the "not found" warning below
    # can only ever refer to image.json itself.
    with open("image.json", "r", encoding="utf-8") as f:
        data = json.load(f)
except FileNotFoundError:
    print("Warning: image.json not found.")
else:
    image_dir = Path("images")
    # is_dir() (rather than exists()) keeps a stray regular file named
    # "images" from crashing iterdir().
    if image_dir.is_dir():
        # List the directory once instead of once per description.
        local_files = [entry for entry in image_dir.iterdir() if entry.is_file()]
        for page in data.get("pages", []):
            for img in page.get("images", []):
                description = img.get("description", "")
                if not description:
                    continue
                needle = description.lower()
                # First file whose name contains the description wins.
                match = next(
                    (entry for entry in local_files if needle in entry.name.lower()),
                    None,
                )
                if match is not None:
                    image_list.append({"file": str(match), "description": description})
        print(f"Found {len(image_list)} local images for semantic search.")
        print("Precomputing embeddings for local images...")
        for img in image_list:
            img["embedding"] = embedding_model.encode(img["description"])
        print("Embeddings precomputed.")
    else:
        print("Warning: 'images' directory not found.")
# Load intents from JSON for the new rule-based classifier.
# `intents_data` is iterated later as {intent_name: [trigger phrases]}.
try:
    with open("intents.json", "r", encoding="utf-8") as f:
        intents_data = json.load(f)
except FileNotFoundError:
    print("FATAL: intents.json not found. This file is required.")
    # Exit with a failure status; the bare `exit()` helper reports success.
    sys.exit(1)
else:
    print("Local intent classifier phrases loaded successfully.")
# --- 3. HELPER FUNCTIONS ---
def get_user_intent_local(user_query: str) -> dict:
    """Classify a query by looking for known trigger phrases inside it.

    Intents from ``intents_data`` are scanned in order; the first trigger
    phrase found (case-insensitively) in the query decides the intent. The
    returned ``query`` is the lower-cased input with that phrase removed, or
    the untouched input when nothing is left. Without any match the intent is
    ``"chat"`` and the query is returned unchanged.

    Returns:
        A dict with the keys ``"intent"`` and ``"query"``.
    """
    normalized = user_query.lower()
    result = None

    for intent_name, triggers in intents_data.items():
        # First trigger of this intent that occurs in the query, if any.
        hit = next(
            (trigger.lower() for trigger in triggers if trigger.lower() in normalized),
            None,
        )
        if hit is None:
            continue
        remainder = normalized.replace(hit, "").strip()
        result = {
            "intent": intent_name,
            "query": remainder or user_query,
        }
        break

    if result is None:
        # Nothing matched: treat the message as ordinary conversation.
        result = {"intent": "chat", "query": user_query}

    print(f"Local Intent Classifier Result: {result}")
    return result
def summarize_analysis_with_groq(json_result: dict, context: str, timeout: float = 30.0) -> str:
    """Ask Groq to turn a raw model result into a short, friendly summary.

    Args:
        json_result: JSON-serialisable output of the analysis model.
        context: What the user asked for; quoted inside the prompt.
        timeout: Seconds to wait for the Groq API before giving up. Without a
            timeout a stalled connection would block the request forever.

    Returns:
        The summary text, or a fallback message containing the raw data when
        the request fails for any reason.
    """
    prompt = (
        "\n"
        "You are a helpful assistant. Based on the following technical analysis from a specialized AI model, provide a friendly and concise summary for the user.\n"
        f"Context: The user asked to '{context}'.\n"
        "AI Model's Raw JSON Output:\n"
        "```json\n"
        f"{json.dumps(json_result, indent=2)}\n"
        "```\n"
        "Your friendly, easy-to-understand summary:\n"
    )
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json={"messages": [{"role": "user", "content": prompt}], "model": "llama-3.3-70b-versatile"},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Best effort: a failed summary must not hide the analysis itself.
        print(f"Groq summary error: {e}")
        return f"I finished the analysis, but had trouble summarizing it. Here is the raw data:\n`{json.dumps(json_result)}`"
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two 1-D vectors as a Python float.

    Args:
        vec1: First vector (any array-like of numbers).
        vec2: Second vector, same length as ``vec1``.

    Returns:
        A plain ``float``; ``0.0`` when either vector has zero norm, which
        includes an empty vector.
    """
    a = np.asarray(vec1, dtype=float)
    b = np.asarray(vec2, dtype=float)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    # Checked before the dot product so a zero or empty vector never reaches
    # the division (or a shape mismatch in np.dot).
    if norm_product == 0:
        return 0.0
    # float() keeps the return type the same on every path instead of mixing
    # Python floats with NumPy scalars.
    return float(np.dot(a, b) / norm_product)
def find_best_matching_image(query: str, min_score: float = 0.4) -> Optional[dict]:
    """Return the local image whose description is most similar to ``query``.

    Args:
        query: Free-text description of the wanted image.
        min_score: The best cosine similarity must be strictly greater than
            this for a match to be reported. Defaults to 0.4.

    Returns:
        The best entry of ``image_list``, or ``None`` when the list is empty
        or the best score does not exceed ``min_score``.
    """
    if not image_list:
        return None
    query_emb = embedding_model.encode(query)
    # Score every image exactly once and keep score and entry together, so
    # the winner's similarity is not recomputed afterwards. Entries without
    # an embedding fall back to [] and therefore score 0.0.
    best_score, best_match = max(
        (
            (cosine_similarity(query_emb, img.get("embedding", [])), img)
            for img in image_list
        ),
        key=lambda scored: scored[0],
    )
    return best_match if best_score > min_score else None
def generate_groq_narrative(user_query: str, search_result: Optional[dict], timeout: float = 30.0) -> str:
    """Ask Groq for a short reply describing the outcome of an image search.

    Args:
        user_query: What the user asked to find.
        search_result: The matched image entry (its ``"description"`` is
            quoted in the prompt), or ``None``/empty when nothing matched.
        timeout: Seconds to wait for the Groq API before giving up. Without a
            timeout a stalled connection would block the request forever.

    Returns:
        The generated reply, or a fixed apology when the request fails.
    """
    if search_result:
        prompt = f"A user asked to find: '{user_query}'. You found an image described as: '{search_result['description']}'. Craft a short, friendly response."
    else:
        prompt = f"A user asked to find: '{user_query}'. You searched but couldn't find a match. Craft a short, polite response."
    try:
        # NOTE(review): this model id differs from the one used for analysis
        # summaries; confirm it is still served by Groq.
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json={"messages": [{"role": "user", "content": prompt}], "model": "llama3-8b-8192"},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Best effort: the search result is still shown without a narrative.
        print(f"Groq narrative error: {e}")
        return "I had an issue describing the search result."
# --- 4. CORE GRADIO LOGIC ---
def handle_image_analysis(file_path: str) -> str:
    """Run the uploaded image through the vision client and summarise it.

    The raw prediction from the ``/predict`` endpoint is handed to
    ``summarize_analysis_with_groq``. Any failure is turned into an apology
    string that includes the error text.
    """
    try:
        raw_prediction = vision_client.predict(
            image=handle_file(file_path),
            api_name="/predict",
        )
        return summarize_analysis_with_groq(raw_prediction, "Analyze this image")
    except Exception as err:
        return f"Sorry, I couldn't analyze the image. Error: {err}"
def handle_audio_analysis(file_path: str) -> str:
    """Send the uploaded audio to the audio client and report its label.

    The ``/predict`` endpoint is unpacked as a pair; only the first element
    is shown to the user. Any failure is turned into an apology string that
    includes the error text.
    """
    try:
        uploaded_audio = handle_file(file_path)
        label, _ignored = audio_client.predict(
            audio_filepath=uploaded_audio,
            api_name="/predict",
        )
        return f"The audio analysis result is: **{label}**"
    except Exception as err:
        return f"Sorry, I couldn't analyze the audio. Error: {err}"
def chat_interface(user_input: dict, history: List[Tuple[str, str]]):
    """Handle one submission from the multimodal textbox.

    An uploaded file takes priority over text: the first file is routed by
    extension to image or audio analysis. Otherwise the text is classified
    with ``get_user_intent_local`` and answered according to its intent.

    Args:
        user_input: Textbox value with the keys ``"text"`` and ``"files"``.
            Missing keys or ``None`` values are treated as empty.
        history: Chat history as (user, bot) pairs; ``None`` means empty.
            A non-empty list is extended in place.

    Returns:
        ``(history, None)``; the ``None`` goes to the textbox output.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    audio_exts = ('.wav', '.mp3', '.flac')
    # Intents that only need a fixed prompt asking the user to upload a file.
    upload_prompts = {
        "request_image_analysis": "Of course. Please upload the image you want me to analyze.",
        "request_audio_analysis": "I'm ready. Please upload the audio file for analysis.",
    }

    # Tolerate a missing/None payload, text or file list instead of raising
    # KeyError/AttributeError on an empty submission.
    payload = user_input or {}
    user_text = (payload.get("text") or "").strip()
    user_files = payload.get("files") or []
    new_history = history or []

    # Priority 1: Handle file uploads (only the first file is processed).
    if user_files:
        file_path = user_files[0]
        lowered = file_path.lower()
        if lowered.endswith(image_exts):
            bot_message = handle_image_analysis(file_path)
        elif lowered.endswith(audio_exts):
            bot_message = handle_audio_analysis(file_path)
        else:
            bot_message = "I'm not sure how to handle that file type."
        # A one-element tuple marks the user turn as a file.
        new_history.append(((file_path,), bot_message))
        return new_history, None

    # Priority 2: Handle text-only queries.
    if not user_text:
        return new_history, None

    intent_data = get_user_intent_local(user_text)
    intent = intent_data.get("intent")
    query_subject = intent_data.get("query")

    image_turn = None  # extra bot turn carrying a found image, if any
    if intent == "chat":
        try:
            prediction = chatbot_client.predict(user_input=query_subject, api_name="/chatbot_response")
            # Takes the 'content' of the last element of the response.
            bot_message = prediction[-1]['content']
        except Exception as e:
            print(f"Error calling chatbot client: {e}")
            bot_message = "I'm sorry, I'm having trouble connecting to my chat brain right now. Please try again."
    elif intent == "search_local_image":
        found_image = find_best_matching_image(query_subject)
        bot_message = generate_groq_narrative(query_subject, found_image)
        if found_image:
            image_turn = (None, (found_image['file'],))
    elif intent in upload_prompts:
        bot_message = upload_prompts[intent]
    else:
        bot_message = "I'm not sure how to handle that. Can you rephrase?"

    new_history.append((user_text, bot_message))
    if image_turn is not None:
        new_history.append(image_turn)
    return new_history, None
# --- 5. GRADIO UI DEFINITION ---
# NOTE(review): the nesting below was reconstructed from flattened source;
# confirm that both render() calls belong at the Blocks level (not in the Row).
with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Modal AI Chatbot") as demo:
    gr.Markdown("# Multi-Modal AI Chatbot")
    gr.Markdown("I can chat, search for local images, or analyze images and audio you upload.")

    # Both components are created with render=False and placed explicitly
    # by the render() calls further down.
    chatbot_history = gr.Chatbot(
        height=600,
        show_copy_button=True,
        layout="bubble",
        render=False,
    )
    textbox_options = {
        "file_types": ["image", "audio"],
        "placeholder": "Type your message or upload a file...",
        "submit_btn": "Send",
        "render": False,
        "autofocus": True,
    }
    with gr.Row():
        multimodal_textbox = gr.MultimodalTextbox(**textbox_options)

    chatbot_history.render()
    multimodal_textbox.render()

    # On submit: chat_interface receives (textbox value, history) and returns
    # (new history, new textbox value).
    multimodal_textbox.submit(
        fn=chat_interface,
        inputs=[multimodal_textbox, chatbot_history],
        outputs=[chatbot_history, multimodal_textbox],
    )

if __name__ == "__main__":
    demo.launch(debug=True)