File size: 9,950 Bytes
8dbca95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import json
import os
from pathlib import Path
import numpy as np
import requests
from typing import Optional, List, Tuple, Dict

import gradio as gr
from dotenv import load_dotenv
from gradio_client import Client, handle_file
from sentence_transformers import SentenceTransformer

# --- 1. SETUP AND MODEL LOADING ---
# Read configuration from a local .env file (if any) into the process environment.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Fail fast: the Groq helpers below cannot authenticate without the key.
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found. It's needed for summarizing analysis results.")

print("Loading models and connecting to clients...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
try:
    chatbot_client = Client("Anvit25/LLM_chatbot2")
    audio_client = Client("Anvit25/new_audio")
    vision_client = Client("Anvit25/vision-classifier")
    print("All models and clients loaded successfully.")
except Exception as e:
    print(f"FATAL: Failed to connect to a Gradio client: {e}")
    # `exit()` is a convenience injected by the `site` module for interactive
    # use and, called without arguments, reports success (status 0).
    # Raise SystemExit with a non-zero status so supervisors see the failure.
    raise SystemExit(1) from e

# --- 2. LOAD & PRECOMPUTE DATA FOR LOCAL SEARCH & INTENT ---
# Load local image data
image_list = []
try:
    # JSON text is UTF-8 by specification, so do not rely on the locale default.
    with open("image.json", "r", encoding="utf-8") as f:
        data = json.load(f)
    image_dir = Path("images")
    # is_dir() (rather than exists()) also rejects a plain file named "images",
    # which would otherwise make iterdir() raise NotADirectoryError.
    if image_dir.is_dir():
        # Scan the directory once, up front, instead of once per description.
        candidate_files = [
            (img_file, img_file.name.lower())
            for img_file in image_dir.iterdir()
            if img_file.is_file()
        ]
        for page in data.get("pages", []):
            for img in page.get("images", []):
                description = img.get("description", "")
                if not description:
                    continue
                needle = description.lower()
                # The first file whose name contains the description wins.
                matched_file = next(
                    (path for path, lowered_name in candidate_files if needle in lowered_name),
                    None,
                )
                if matched_file is not None:
                    image_list.append({"file": str(matched_file), "description": description})
        print(f"Found {len(image_list)} local images for semantic search.")
        print("Precomputing embeddings for local images...")
        # Embed each description once at start-up so searches only embed the query.
        for img in image_list:
            img["embedding"] = embedding_model.encode(img["description"])
        print("Embeddings precomputed.")
    else:
        print("Warning: 'images' directory not found.")
except FileNotFoundError:
    # Best effort: without image.json the app runs with an empty image_list.
    print("Warning: image.json not found.")

# Load intents from JSON for the new rule-based classifier
try:
    # JSON text is UTF-8 by specification, so do not rely on the locale default.
    with open("intents.json", "r", encoding="utf-8") as f:
        intents_data = json.load(f)
    print("Local intent classifier phrases loaded successfully.")
except FileNotFoundError as e:
    print("FATAL: intents.json not found. This file is required.")
    # `exit()` is a `site` convenience that reports success (status 0) when
    # called without arguments; a fatal start-up error must exit non-zero.
    raise SystemExit(1) from e

# --- 3. HELPER FUNCTIONS ---

def get_user_intent_local(user_query: str) -> dict:
    """Classify *user_query* by looking for known trigger phrases.

    The module-level ``intents_data`` mapping (intent name -> list of trigger
    phrases) is walked in order. The first phrase found in the query, compared
    case-insensitively as a substring, decides the intent.

    Args:
        user_query: Raw text typed by the user.

    Returns:
        A dict with two keys:
        ``"intent"`` - the matched intent name, or ``"chat"`` if nothing matched.
        ``"query"`` - the lower-cased query with every occurrence of the
        trigger phrase removed and surrounding whitespace stripped; the
        original query is used when nothing is left or nothing matched.
    """
    lower_query = user_query.lower()

    # Iterate through intents and their trigger phrases
    for intent, phrases in intents_data.items():
        for phrase in phrases:
            needle = phrase.lower()
            # An empty phrase is a substring of every string, and a blank one
            # matches any query containing that whitespace. Either would
            # hijack classification, so such entries are ignored.
            if not needle.strip():
                continue
            if needle in lower_query:
                subject = lower_query.replace(needle, "").strip()
                result = {
                    "intent": intent,
                    "query": subject or user_query,
                }
                print(f"Local Intent Classifier Result: {result}")
                return result

    # If no specific trigger phrase is found, default to a general chat
    result = {"intent": "chat", "query": user_query}
    print(f"Local Intent Classifier Result: {result}")
    return result

def summarize_analysis_with_groq(json_result: dict, context: str, timeout: float = 30.0) -> str:
    """Ask Groq to turn a raw analysis result into a human-readable summary.

    Args:
        json_result: JSON-serialisable result produced by an analysis model.
        context: Short description of what the user asked for; embedded in
            the prompt.
        timeout: Seconds to wait for the Groq API before giving up. Without a
            timeout ``requests`` waits indefinitely, which would hang the
            handler on a stalled connection.

    Returns:
        The summary text from the first completion choice. If the request or
        the response parsing fails for any reason, a fallback message that
        includes the raw result as JSON.
    """
    prompt = f"""
    You are a helpful assistant. Based on the following technical analysis from a specialized AI model, provide a friendly and concise summary for the user.
    Context: The user asked to '{context}'.
    AI Model's Raw JSON Output:
    ```json
    {json.dumps(json_result, indent=2)}
    ```
    Your friendly, easy-to-understand summary:
    """
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json={"messages": [{"role": "user", "content": prompt}], "model": "llama-3.3-70b-versatile"},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Deliberately broad: summarising is best effort, and the caller
        # should still get the raw data when anything goes wrong.
        print(f"Groq summary error: {e}")
        return f"I finished the analysis, but had trouble summarizing it. Here is the raw data:\n`{json.dumps(json_result)}`"

def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between *vec1* and *vec2*.

    When either vector has zero length (including an empty vector) the
    similarity is undefined, and 0.0 is returned instead.
    """
    length_a = np.linalg.norm(vec1)
    length_b = np.linalg.norm(vec2)
    if length_a != 0 and length_b != 0:
        return np.dot(vec1, vec2) / (length_a * length_b)
    return 0.0

def find_best_matching_image(query: str, threshold: float = 0.4) -> Optional[dict]:
    """Return the local image whose description best matches *query*.

    Args:
        query: Free-text description of the image being looked for.
        threshold: Similarity the best candidate must exceed (strictly) to be
            accepted. Defaults to 0.4.

    Returns:
        The best-scoring entry of the module-level ``image_list``, or ``None``
        when the list is empty or no candidate scores above *threshold*.
    """
    if not image_list:
        return None

    query_emb = embedding_model.encode(query)
    # Score every candidate exactly once; entries without a stored embedding
    # fall back to an empty vector, which cosine_similarity scores as 0.0.
    scored = [
        (cosine_similarity(query_emb, img.get("embedding", [])), img)
        for img in image_list
    ]
    # max() keeps the first of several equal scores, i.e. list order breaks ties.
    highest_score, best_match = max(scored, key=lambda pair: pair[0])
    return best_match if highest_score > threshold else None

def generate_groq_narrative(user_query: str, search_result: Optional[dict], timeout: float = 30.0) -> str:
    """Ask Groq for a short chat reply describing an image-search outcome.

    Args:
        user_query: What the user asked to find.
        search_result: The matched image entry (its ``"description"`` is put
            into the prompt), or ``None`` when nothing matched.
        timeout: Seconds to wait for the Groq API before giving up. Without a
            timeout ``requests`` waits indefinitely on a stalled connection.

    Returns:
        The text of the first completion choice, or a fixed apology string if
        the request or the response parsing fails.
    """
    if search_result:
        prompt = f"A user asked to find: '{user_query}'. You found an image described as: '{search_result['description']}'. Craft a short, friendly response."
    else:
        prompt = f"A user asked to find: '{user_query}'. You searched but couldn't find a match. Craft a short, polite response."
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            # NOTE: 'llama3-8b-8192' was deprecated by Groq in favour of
            # 'llama-3.1-8b-instant' (see Groq's deprecations page).
            json={"messages": [{"role": "user", "content": prompt}], "model": "llama-3.1-8b-instant"},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Deliberately broad: the narrative is best effort and must never
        # take the chat handler down with it.
        print(f"Groq narrative error: {e}")
        return "I had an issue describing the search result."

# --- 4. CORE GRADIO LOGIC ---

def handle_image_analysis(file_path: str) -> str:
    """Send an image to the vision client and return a summary of its result.

    The file at *file_path* is submitted to ``vision_client``'s ``/predict``
    endpoint and the returned result is passed to
    ``summarize_analysis_with_groq``. Any exception along the way is turned
    into an apology message that carries the error text.
    """
    try:
        classifier_output = vision_client.predict(
            image=handle_file(file_path),
            api_name="/predict",
        )
        return summarize_analysis_with_groq(classifier_output, "Analyze this image")
    except Exception as err:
        return f"Sorry, I couldn't analyze the image. Error: {err}"

def handle_audio_analysis(file_path: str) -> str:
    """Send an audio file to the audio client and report its prediction.

    The file at *file_path* is submitted to ``audio_client``'s ``/predict``
    endpoint, which is expected to return a two-item result; only the first
    item (the prediction text) is used. Any exception is turned into an
    apology message that carries the error text.
    """
    try:
        outputs = audio_client.predict(
            audio_filepath=handle_file(file_path),
            api_name="/predict",
        )
        prediction_text, _ = outputs
    except Exception as err:
        return f"Sorry, I couldn't analyze the audio. Error: {err}"
    return f"The audio analysis result is: **{prediction_text}**"

def chat_interface(user_input: dict, history: List[Tuple[str, str]]):
    """Handle one submission from the multimodal textbox.

    An uploaded file takes priority over any typed text: the first file is
    routed to image or audio analysis according to its extension. Otherwise
    the text is classified with ``get_user_intent_local`` and answered by the
    chat client, the local image search, or a fixed prompt.

    Args:
        user_input: Mapping with a ``"text"`` entry (the typed message) and a
            ``"files"`` entry (uploaded files). Missing or ``None`` entries
            are treated as empty.
        history: Chat history so far as ``(user, bot)`` tuples, or ``None``.
            It is not modified; a new list is returned.

    Returns:
        A ``(history, None)`` pair: the updated history and ``None`` for the
        textbox output.
    """
    user_text = (user_input.get("text") or "").strip()
    user_files = user_input.get("files") or []
    # Work on a copy so the caller's list is never mutated in place.
    new_history = list(history) if history else []

    # Priority 1: Handle file uploads (only the first file is processed)
    if user_files:
        first_file = user_files[0]
        # NOTE(review): uploads may arrive as plain path strings or as dicts
        # carrying a "path" key depending on the Gradio release - accept both.
        file_path = first_file["path"] if isinstance(first_file, dict) else first_file
        lowered_path = file_path.lower()

        if lowered_path.endswith(('.png', '.jpg', '.jpeg', '.webp')):
            bot_message = handle_image_analysis(file_path)
        elif lowered_path.endswith(('.wav', '.mp3', '.flac')):
            bot_message = handle_audio_analysis(file_path)
        else:
            bot_message = "I'm not sure how to handle that file type."

        # A one-element tuple on the user side marks the turn as a file.
        new_history.append(((file_path,), bot_message))
        return new_history, None

    # Priority 2: Handle text-only queries
    if not user_text:
        return new_history, None

    intent_data = get_user_intent_local(user_text)
    intent = intent_data.get("intent")
    query_subject = intent_data.get("query")
    found_image = None

    if intent == "chat":
        try:
            prediction = chatbot_client.predict(user_input=query_subject, api_name="/chatbot_response")
            # presumably a list of message dicts whose last item is the reply - verify against the Space
            bot_message = prediction[-1]['content']
        except Exception as e:
            print(f"Error calling chatbot client: {e}")
            bot_message = "I'm sorry, I'm having trouble connecting to my chat brain right now. Please try again."
    elif intent == "search_local_image":
        found_image = find_best_matching_image(query_subject)
        bot_message = generate_groq_narrative(query_subject, found_image)
    elif intent == "request_image_analysis":
        bot_message = "Of course. Please upload the image you want me to analyze."
    elif intent == "request_audio_analysis":
        bot_message = "I'm ready. Please upload the audio file for analysis."
    else:
        bot_message = "I'm not sure how to handle that. Can you rephrase?"

    new_history.append((user_text, bot_message))
    if found_image:
        # Show the matched image as a separate bot-only turn.
        new_history.append((None, (found_image['file'],)))
    return new_history, None


# --- 5. GRADIO UI DEFINITION ---
# Build the page: a title, a one-line description, the chat transcript and a
# multimodal input box wired to chat_interface.
with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Modal AI Chatbot") as demo:
    gr.Markdown("# Multi-Modal AI Chatbot")
    gr.Markdown("I can chat, search for local images, or analyze images and audio you upload.")

    # Both components are constructed with render=False and placed explicitly
    # by the .render() calls further down.
    chatbot_history = gr.Chatbot(height=600, show_copy_button=True, layout="bubble", render=False)
    
    # NOTE(review): the textbox is constructed inside this Row but rendered
    # after the Row block has closed, so it presumably does not end up inside
    # the Row - confirm this is the intended layout.
    with gr.Row():
        multimodal_textbox = gr.MultimodalTextbox(
            file_types=["image", "audio"],
            placeholder="Type your message or upload a file...",
            submit_btn="Send",
            render=False,
            autofocus=True
        )

    # Render order: chat transcript first, then the input box.
    chatbot_history.render()
    multimodal_textbox.render()

    # On submit, chat_interface receives the textbox value and the current
    # history; its two return values go to the chatbot and the textbox.
    # chat_interface returns None for the textbox (presumably clearing it).
    multimodal_textbox.submit(
        fn=chat_interface,
        inputs=[multimodal_textbox, chatbot_history],
        outputs=[chatbot_history, multimodal_textbox]
    )

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)