Anvit25's picture
Update app.py
6f9996c verified
import json
import os
import re
from pathlib import Path
from typing import Optional, List, Tuple, Dict

import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from gradio_client import Client, handle_file
from sentence_transformers import SentenceTransformer
# --- 1. SETUP AND MODEL LOADING ---
# Read variables from a local .env file (if present) into the process environment.
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found. It's needed for summarizing analysis results.")

print("Loading models and connecting to clients...")
# Local sentence-embedding model used for semantic image search.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
try:
    # Remote Gradio Spaces that do the heavy lifting for chat, audio and vision.
    chatbot_client = Client("Anvit25/LLM_chatbot2")
    audio_client = Client("Anvit25/new_audio")
    vision_client = Client("Anvit25/vision-classifier")
    print("All models and clients loaded successfully.")
except Exception as e:
    print(f"FATAL: Failed to connect to a Gradio client: {e}")
    # Exit with a non-zero status: the bare `exit()` helper exits with status 0,
    # which makes a fatal startup failure look like success to supervisors.
    raise SystemExit(1) from e
# --- 2. LOAD & PRECOMPUTE DATA FOR LOCAL SEARCH & INTENT ---
# Load local image data.
# Each entry of `image_list` is {"file": <path str>, "description": <str>, "embedding": <vector>}.
# Descriptions come from image.json ("pages" -> "images" -> "description"); each one is
# paired with the first file in ./images whose name contains it (case-insensitive).
image_list = []
try:
    with open("image.json", "r", encoding="utf-8") as f:
        data = json.load(f)
    image_dir = Path("images")
    if image_dir.exists():
        # List the directory once instead of re-scanning it for every description.
        image_files = [p for p in image_dir.iterdir() if p.is_file()]
        for page in data.get("pages", []):
            for img in page.get("images", []):
                description = img.get("description", "")
                if not description:
                    continue
                needle = description.lower()
                match = next((p for p in image_files if needle in p.name.lower()), None)
                if match is not None:
                    image_list.append({"file": str(match), "description": description})
        print(f"Found {len(image_list)} local images for semantic search.")
        print("Precomputing embeddings for local images...")
        if image_list:
            # One batched encode call is much faster than encoding descriptions one by one.
            embeddings = embedding_model.encode([img["description"] for img in image_list])
            for img, embedding in zip(image_list, embeddings):
                img["embedding"] = embedding
        print("Embeddings precomputed.")
    else:
        print("Warning: 'images' directory not found.")
except FileNotFoundError:
    print("Warning: image.json not found.")
except json.JSONDecodeError as e:
    # A malformed index should disable local search, not crash the whole app.
    print(f"Warning: image.json is not valid JSON ({e}); local image search disabled.")
# Load intent -> trigger-phrase rules used by get_user_intent_local (required at startup).
try:
    with open("intents.json", "r", encoding="utf-8") as f:
        intents_data = json.load(f)
except FileNotFoundError:
    print("FATAL: intents.json not found. This file is required.")
    # Non-zero status; bare `exit()` would report success.
    raise SystemExit(1) from None
except json.JSONDecodeError as e:
    print(f"FATAL: intents.json is not valid JSON: {e}")
    raise SystemExit(1) from None
if not isinstance(intents_data, dict):
    # The classifier iterates `intents_data.items()`, so the top level must be an object.
    print("FATAL: intents.json must map intent names to lists of trigger phrases.")
    raise SystemExit(1)
print("Local intent classifier phrases loaded successfully.")
# --- 3. HELPER FUNCTIONS ---
def get_user_intent_local(user_query: str, *, intents: Optional[Dict[str, List[str]]] = None) -> dict:
    """Classify a user message with the keyword rules from intents.json.

    Each intent maps to a list of trigger phrases. The first phrase, in the
    mapping's iteration order, that appears in the query as a whole word or
    phrase decides the intent. Matching is case-insensitive. That phrase is
    removed from the query, and the remaining text becomes the ``"query"``
    subject.

    Args:
        user_query: Raw text typed by the user.
        intents: Mapping of intent name -> trigger phrases. Defaults to the
            module-level ``intents_data`` loaded at startup.

    Returns:
        ``{"intent": <name>, "query": <subject>}``. If removing the phrase
        leaves nothing, ``query`` is the original text. When no phrase
        matches, returns ``{"intent": "chat", "query": user_query}``.
    """
    if intents is None:
        intents = intents_data
    for intent, phrases in intents.items():
        for phrase in phrases:
            phrase = phrase.strip()
            if not phrase:
                # An empty phrase would "match" every query and hijack classification.
                continue
            # Whole-word match, so a trigger like "hi" does not fire inside "this".
            pattern = re.compile(rf"(?<!\w){re.escape(phrase)}(?!\w)", re.IGNORECASE)
            if pattern.search(user_query):
                # Remove the trigger but keep the user's original casing; collapse leftover spaces.
                subject = " ".join(pattern.sub(" ", user_query).split())
                result = {
                    "intent": intent,
                    "query": subject if subject else user_query
                }
                print(f"Local Intent Classifier Result: {result}")
                return result
    # If no specific trigger phrase is found, default to a general chat
    result = {"intent": "chat", "query": user_query}
    print(f"Local Intent Classifier Result: {result}")
    return result
def summarize_analysis_with_groq(json_result: dict, context: str, *, timeout: float = 30.0) -> str:
    """Turn a model's raw analysis output into a friendly summary via Groq.

    Args:
        json_result: Raw result returned by an analysis model. It is embedded
            in the prompt as pretty-printed JSON.
        context: Short description of what the user asked for, such as
            "Analyze this image".
        timeout: Seconds to wait for the Groq HTTP request before giving up.
            Without a timeout, a stalled connection would block the chat
            handler indefinitely.

    Returns:
        The summary text from the model. On any failure (network, HTTP error,
        unexpected response shape), returns a fallback message that includes
        the raw data. The error is printed rather than raised.
    """
    # default=str: the result is only displayed to the LLM/user, so stringify
    # anything that is not natively JSON-serializable instead of raising here.
    raw_json = json.dumps(json_result, indent=2, default=str)
    prompt = f"""
You are a helpful assistant. Based on the following technical analysis from a specialized AI model, provide a friendly and concise summary for the user.
Context: The user asked to '{context}'.
AI Model's Raw JSON Output:
```json
{raw_json}
```
Your friendly, easy-to-understand summary:
"""
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json={"messages": [{"role": "user", "content": prompt}], "model": "llama-3.3-70b-versatile"},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Best-effort boundary: log it and degrade to showing the raw data.
        print(f"Groq summary error: {e}")
        return f"I finished the analysis, but had trouble summarizing it. Here is the raw data:\n`{json.dumps(json_result, default=str)}`"
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two vectors.

    If either vector has zero magnitude (including an empty vector), the
    similarity is undefined, and ``0.0`` is returned instead.
    """
    magnitude_a, magnitude_b = np.linalg.norm(vec1), np.linalg.norm(vec2)
    if not (magnitude_a and magnitude_b):
        return 0.0
    return np.dot(vec1, vec2) / (magnitude_a * magnitude_b)
def find_best_matching_image(query: str, min_score: float = 0.4) -> Optional[dict]:
    """Return the local image whose description best matches ``query``.

    Args:
        query: Free-text description of what the user wants to see.
        min_score: A match is accepted only if its cosine similarity is
            strictly greater than this value. Defaults to 0.4, the app's
            original threshold.

    Returns:
        The best-scoring entry of ``image_list``, or ``None`` if there are no
        local images or the best score does not exceed ``min_score``.
    """
    if not image_list:
        return None
    query_emb = embedding_model.encode(query)
    # Score every image exactly once; max() keeps the first entry on ties.
    highest_score, best_match = max(
        ((cosine_similarity(query_emb, img.get("embedding", [])), img) for img in image_list),
        key=lambda scored: scored[0],
    )
    return best_match if highest_score > min_score else None
def generate_groq_narrative(user_query: str, search_result: Optional[dict], *, timeout: float = 30.0) -> str:
    """Ask Groq for a short chat reply that describes an image-search outcome.

    Args:
        user_query: What the user asked to find.
        search_result: The matched image entry (its ``"description"`` is
            quoted in the prompt), or ``None``/falsy when nothing matched.
        timeout: Seconds to wait for the Groq HTTP request before giving up.

    Returns:
        The generated reply. On any failure, returns a fixed apology string;
        the error is printed rather than raised.
    """
    if search_result:
        prompt = f"A user asked to find: '{user_query}'. You found an image described as: '{search_result['description']}'. Craft a short, friendly response."
    else:
        prompt = f"A user asked to find: '{user_query}'. You searched but couldn't find a match. Craft a short, polite response."
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            # "llama3-8b-8192" has been decommissioned by Groq (every call failed into the
            # fallback below); Groq lists "llama-3.1-8b-instant" as its replacement.
            json={"messages": [{"role": "user", "content": prompt}], "model": "llama-3.1-8b-instant"},
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        print(f"Groq narrative error: {e}")
        return "I had an issue describing the search result."
# --- 4. CORE GRADIO LOGIC ---
def handle_image_analysis(file_path: str) -> str:
    """Send an uploaded image to the vision Space and return a friendly summary.

    Any exception raised along the way is turned into an apology message for
    the chat instead of propagating.
    """
    try:
        raw_prediction = vision_client.predict(image=handle_file(file_path), api_name="/predict")
        return summarize_analysis_with_groq(raw_prediction, "Analyze this image")
    except Exception as err:
        return f"Sorry, I couldn't analyze the image. Error: {err}"
def handle_audio_analysis(file_path: str) -> str:
    """Send an uploaded audio file to the audio Space and report its prediction.

    The Space's /predict endpoint is unpacked as a two-item result; only the
    first item is shown. Any exception becomes an apology message.
    """
    try:
        outputs = audio_client.predict(audio_filepath=handle_file(file_path), api_name="/predict")
        label, _unused = outputs
    except Exception as err:
        return f"Sorry, I couldn't analyze the audio. Error: {err}"
    return f"The audio analysis result is: **{label}**"
# File extensions routed to each analyzer (compared case-insensitively).
_IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.webp')
_AUDIO_SUFFIXES = ('.wav', '.mp3', '.flac')
# Fixed replies for intents that just prompt the user to upload something.
_CANNED_REPLIES = {
    "request_image_analysis": "Of course. Please upload the image you want me to analyze.",
    "request_audio_analysis": "I'm ready. Please upload the audio file for analysis.",
}


def _respond_to_file(file_path: str) -> str:
    """Route an uploaded file to the image or audio analyzer based on its extension."""
    lowered = file_path.lower()
    if lowered.endswith(_IMAGE_SUFFIXES):
        return handle_image_analysis(file_path)
    if lowered.endswith(_AUDIO_SUFFIXES):
        return handle_audio_analysis(file_path)
    return "I'm not sure how to handle that file type."


def _chat_with_llm(message: str) -> str:
    """Forward a message to the remote chatbot Space, with a friendly fallback on errors."""
    try:
        prediction = chatbot_client.predict(user_input=message, api_name="/chatbot_response")
        # The last entry of the returned conversation holds the bot's reply.
        return prediction[-1]['content']
    except Exception as e:
        print(f"Error calling chatbot client: {e}")
        return "I'm sorry, I'm having trouble connecting to my chat brain right now. Please try again."


def chat_interface(user_input: dict, history: List[Tuple[str, str]]):
    """Handle one submission from the multimodal textbox.

    Args:
        user_input: Textbox value with ``"text"`` and ``"files"`` keys.
            Missing or ``None`` values are treated as empty.
        history: Chat history as ``(user, bot)`` pairs, or ``None``. It is
            extended in place. File messages are stored as a one-element
            ``(path,)`` tuple.

    Returns:
        ``(history, None)``. The ``None`` is sent to the textbox output,
        presumably clearing it.

    Uploaded files take priority. Only the first file is analyzed, and any
    text sent alongside it is ignored. Text-only messages are routed by
    ``get_user_intent_local``.
    """
    user_text = (user_input.get("text") or "").strip()
    user_files = user_input.get("files") or []
    new_history = history or []

    # Priority 1: Handle file uploads
    if user_files:
        file_path = user_files[0]
        new_history.append(((file_path,), _respond_to_file(file_path)))
        return new_history, None

    # Priority 2: Handle text-only queries
    if not user_text:
        return new_history, None

    intent_data = get_user_intent_local(user_text)
    intent = intent_data.get("intent")
    query_subject = intent_data.get("query")

    if intent == "chat":
        # Send the user's full message: the trigger-stripped subject would drop
        # words the chatbot needs for context.
        bot_message = _chat_with_llm(user_text)
    elif intent == "search_local_image":
        found_image = find_best_matching_image(query_subject)
        new_history.append((user_text, generate_groq_narrative(query_subject, found_image)))
        if found_image:
            # Show the matched image as a separate bot message.
            new_history.append((None, (found_image['file'],)))
        return new_history, None
    else:
        bot_message = _CANNED_REPLIES.get(intent, "I'm not sure how to handle that. Can you rephrase?")
    new_history.append((user_text, bot_message))
    return new_history, None
# --- 5. GRADIO UI DEFINITION ---
with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Modal AI Chatbot") as demo:
    gr.Markdown("# Multi-Modal AI Chatbot")
    gr.Markdown("I can chat, search for local images, or analyze images and audio you upload.")
    # Components are declared in display order (conversation above, input below),
    # so no render=False / .render() indirection or wrapper Row is needed.
    chatbot_history = gr.Chatbot(height=600, show_copy_button=True, layout="bubble")
    multimodal_textbox = gr.MultimodalTextbox(
        file_types=["image", "audio"],
        placeholder="Type your message or upload a file...",
        submit_btn="Send",
        autofocus=True,
    )
    # chat_interface returns (updated history, None); the None goes to the textbox.
    multimodal_textbox.submit(
        fn=chat_interface,
        inputs=[multimodal_textbox, chatbot_history],
        outputs=[chatbot_history, multimodal_textbox],
    )
if __name__ == "__main__":
    # Start the web UI only when run as a script. debug=True is passed straight
    # to Gradio's launch(); per Gradio's docs it blocks the main thread.
    launch_options = {"debug": True}
    demo.launch(**launch_options)