# Hugging Face file-view header (author: Anvit25)
# Commit 8dbca95: "Added Gradio app.py" (file size 22.5 kB)
# import json
# import os
# from pathlib import Path
# import numpy as np
# from fastapi import FastAPI, Query
# from fastapi.responses import FileResponse
# # Use a dedicated library for creating text embeddings
# from sentence_transformers import SentenceTransformer
# # --- 1. Load the Local Embedding Model ---
# # This line downloads (first time only) and loads a powerful, lightweight model
# # into memory. This is much more efficient than using an API for this task.
# print("Loading sentence-transformer model...")
# embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# print("Model loaded successfully.")
# # --- 2. Load Image Metadata ---
# try:
# with open("image.json", "r") as f:
# data = json.load(f)
# except FileNotFoundError:
# print("Error: image.json not found. Please make sure the file exists.")
# exit()
# image_list = []
# image_dir = Path("images")
# if not image_dir.exists():
# print(f"Error: The '{image_dir}' directory does not exist.")
# exit()
# # Prepare list of images and descriptions
# for page in data.get("pages", []):
# for img in page.get("images", []):
# description = img.get("description", "")
# if not description:
# continue
# # Match description to a file in the images folder
# for img_file in image_dir.iterdir():
# if img_file.is_file() and description.lower() in img_file.name.lower():
# image_list.append({"file": str(img_file), "description": description})
# break # Move to the next description once a match is found
# print(f"Found {len(image_list)} images with matching descriptions.")
# # --- 3. Function to Get Embeddings Locally ---
# def get_embedding(text: str) -> np.ndarray:
# """
# Generates an embedding for the given text using the local SentenceTransformer model.
# """
# # The model.encode() method directly returns a numpy array. It's fast and local.
# return embedding_model.encode(text)
# # --- 4. Precompute Embeddings for All Images ---
# print("Precomputing embeddings for all image descriptions...")
# for img in image_list:
# # Each description is converted into a numerical vector (embedding)
# img["embedding"] = get_embedding(img["description"])
# print("Embeddings precomputed.")
# # --- 5. FastAPI Application ---
# app = FastAPI(title="Semantic Image Search API")
# def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
# """Calculates the cosine similarity between two vectors."""
# norm1 = np.linalg.norm(vec1)
# norm2 = np.linalg.norm(vec2)
# if norm1 == 0 or norm2 == 0:
# return 0.0
# return np.dot(vec1, vec2) / (norm1 * norm2)
# @app.get("/search_image/")
# async def search_image(query: str = Query(..., description="Search text")):
# # Convert the user's search query into an embedding
# query_emb = get_embedding(query)
# best_match = None
# highest_score = -1.0 # Cosine similarity ranges from -1 to 1
# # Compare the query embedding to all precomputed image description embeddings
# for img in image_list:
# score = cosine_similarity(query_emb, img["embedding"])
# if score > highest_score:
# highest_score = score
# best_match = img
# if best_match:
# print(f"Query: '{query}' -> Found best match: {best_match['file']} with score: {highest_score:.4f}")
# return FileResponse(best_match["file"])
# return {"error": "No matching image found"}
# import json
# import os
# from pathlib import Path
# import numpy as np
# import requests
# from typing import Optional, List, Tuple, Dict
# import gradio as gr
# from dotenv import load_dotenv
# from gradio_client import Client, handle_file
# from sentence_transformers import SentenceTransformer, util
# # --- 1. SETUP AND MODEL LOADING ---
# load_dotenv()
# GROQ_API_KEY = os.getenv("GROQ_API_KEY") # Still used for summarizing results
# if not GROQ_API_KEY:
# raise ValueError("GROQ_API_KEY not found. It's needed for summarizing analysis results.")
# print("Loading models and connecting to clients...")
# # Model for local intent classification and image search
# embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# try:
# chatbot_client = Client("Anvit25/LLM_chatbot2")
# audio_client = Client("Anvit25/new_audio")
# vision_client = Client("Anvit25/vision-classifier")
# print("All models and clients loaded successfully.")
# except Exception as e:
# print(f"FATAL: Failed to connect to a Gradio client: {e}")
# exit()
# # --- 2. LOAD & PRECOMPUTE DATA FOR LOCAL SEARCH & INTENT ---
# # Load local image data (same as before)
# image_list = []
# # ... (Your image loading and embedding logic is unchanged) ...
# # NEW: Load intents from JSON and pre-compute their embeddings
# intent_embeddings = {}
# try:
# with open("intents.json", "r") as f:
# intents_data = json.load(f)
# for intent, phrases in intents_data.items():
# intent_embeddings[intent] = {
# "phrases": phrases,
# "embeddings": embedding_model.encode(phrases)
# }
# print("Local intent classifier loaded successfully.")
# except FileNotFoundError:
# print("FATAL: intents.json not found. This file is required for the local intent classifier.")
# exit()
# # --- 3. HELPER FUNCTIONS ---
# def get_user_intent_local(user_query: str) -> dict:
# """
# Uses SentenceTransformer to classify user intent locally based on intents.json.
# """
# query_embedding = embedding_model.encode(user_query)
# best_match = {"intent": "chat", "score": 0.7, "query": user_query} # Default to chat
# for intent, data in intent_embeddings.items():
# # Calculate cosine similarity between user query and all trigger phrases for an intent
# scores = util.cos_sim(query_embedding, data["embeddings"])[0]
# max_score = max(scores)
# if max_score > best_match["score"]:
# best_match["score"] = max_score.item()
# best_match["intent"] = intent
# # Extract the subject by removing the trigger phrase
# best_phrase_index = np.argmax(scores)
# trigger_phrase = data["phrases"][best_phrase_index]
# subject = user_query.lower().replace(trigger_phrase.lower(), "").strip()
# best_match["query"] = subject if subject else user_query
# print(f"Local Intent Classifier Result: {best_match}")
# return best_match
# def summarize_analysis_with_groq(json_result: dict, context: str) -> str:
# """
# NEW: Takes a JSON/dict result and uses Groq to create a human-readable summary.
# """
# prompt = f"""
# You are a helpful assistant. Based on the following technical analysis from a specialized AI model, provide a friendly and concise summary for the user.
# Context: The user asked to '{context}'.
# AI Model's Raw JSON Output:
# ```json
# {json.dumps(json_result, indent=2)}
# ```
# Your friendly, easy-to-understand summary:
# """
# try:
# response = requests.post(
# "https://api.groq.com/openai/v1/chat/completions",
# headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
# json={"messages": [{"role": "user", "content": prompt}], "model": "llama3-8b-8192"},
# )
# response.raise_for_status()
# return response.json()["choices"][0]["message"]["content"]
# except Exception as e:
# print(f"Groq summary error: {e}")
# return f"I finished the analysis, but had trouble summarizing it. Here is the raw data:\n`{json.dumps(json_result)}`"
# # ... (cosine_similarity and find_best_matching_image functions are unchanged) ...
# def find_best_matching_image(query: str) -> Optional[dict]: # ... (Identical) ...
# pass
# def generate_groq_narrative(user_query: str, search_result: Optional[dict]) -> str: # ... (Identical) ...
# pass
# # --- 4. CORE GRADIO LOGIC (UPDATED) ---
# def handle_image_analysis(file_path: str) -> str:
# """Analyzes an image and returns a text summary."""
# try:
# vision_result = vision_client.predict(image=handle_file(file_path), api_name="/predict")
# # NEW: Summarize the JSON result
# summary = summarize_analysis_with_groq(vision_result, "Analyze this image")
# return summary
# except Exception as e:
# return f"Sorry, I couldn't analyze the image. Error: {e}"
# def handle_audio_analysis(file_path: str) -> str:
# """Analyzes audio and returns a text summary."""
# try:
# prediction_text, _ = audio_client.predict(audio_filepath=handle_file(file_path), api_name="/predict")
# return f"The audio analysis result is: **{prediction_text}**"
# except Exception as e:
# return f"Sorry, I couldn't analyze the audio. Error: {e}"
# def chat_interface(user_input: dict, history: List[Tuple[str, str]]):
# """
# The main function that powers the Gradio chat interface.
# It now prioritizes file uploads over text for intent classification.
# """
# user_text = user_input["text"].strip()
# user_files = user_input["files"]
# new_history = history or []
# bot_message = ""
# # === Priority 1: Handle file uploads ===
# if user_files:
# file_path = user_files[0]
# # Display the uploaded file in the chat
# new_history.append(((file_path,), None))
# if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
# bot_message = handle_image_analysis(file_path)
# elif file_path.lower().endswith(('.wav', '.mp3', '.flac')):
# bot_message = handle_audio_analysis(file_path)
# else:
# bot_message = "I'm not sure how to handle that file type."
# new_history[-1] = (new_history[-1][0], bot_message)
# return new_history, None
# # === Priority 2: Handle text-only queries if no files are uploaded ===
# if not user_text:
# return new_history, None
# new_history.append((user_text, None))
# intent_data = get_user_intent_local(user_text) # Use local classifier
# intent = intent_data.get("intent")
# query_subject = intent_data.get("query")
# if intent == "chat":
# prediction = chatbot_client.predict(user_input=query_subject, api_name="/chatbot_response")
# bot_message = prediction[-1]['content']
# elif intent == "search_local_image":
# found_image = find_best_matching_image(query_subject)
# bot_message = generate_groq_narrative(query_subject, found_image)
# new_history[-1] = (user_text, bot_message)
# if found_image:
# new_history.append((None, (found_image['file'],))) # Display image on new line
# return new_history, None
# # For these intents, we just prompt the user to upload a file
# elif intent == "request_image_analysis":
# bot_message = "Of course. Please upload the image you want me to analyze."
# elif intent == "request_audio_analysis":
# bot_message = "I'm ready. Please upload the audio file for analysis."
# else:
# bot_message = "I'm not sure how to handle that. Can you rephrase?"
# new_history[-1] = (user_text, bot_message)
# return new_history, None
# # --- 5. GRADIO UI DEFINITION ---
# with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Modal AI Chatbot") as demo:
# gr.Markdown("# Multi-Modal AI Chatbot")
# gr.Markdown("I can chat, search for local images, or analyze images and audio you upload.")
# # CORRECTED LINE: The 'bubble_fn' argument is removed.
# chatbot_history = gr.Chatbot(height=600, show_copy_button=True, layout="bubble", render=False)
# with gr.Row():
# multimodal_textbox = gr.MultimodalTextbox(
# file_types=["image", "audio"],
# placeholder="Type your message or upload a file...",
# submit_btn="Send",
# render=False,
# autofocus=True
# )
# # Render components after defining the layout
# chatbot_history.render()
# multimodal_textbox.render()
# multimodal_textbox.submit(
# fn=chat_interface,
# inputs=[multimodal_textbox, chatbot_history],
# outputs=[chatbot_history, multimodal_textbox]
# )
# if __name__ == "__main__":
# demo.launch(debug=True)
import json
import mimetypes
import os
import re
from pathlib import Path
from typing import Optional, List, Tuple, Dict

import gradio as gr
import numpy as np
import requests
from dotenv import load_dotenv
from gradio_client import Client, handle_file
from sentence_transformers import SentenceTransformer
# --- 1. SETUP AND MODEL LOADING ---
load_dotenv()
# The Groq key is required up front: summaries and search narratives both call Groq.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY not found. It's needed for summarizing analysis results.")
print("Loading models and connecting to clients...")
# Shared local embedding model, used for image search.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
try:
    # Remote Gradio Spaces that do the actual chat / audio / vision work.
    chatbot_client = Client("Anvit25/LLM_chatbot2")
    audio_client = Client("Anvit25/new_audio")
    vision_client = Client("Anvit25/vision-classifier")
    print("All models and clients loaded successfully.")
except Exception as e:
    print(f"FATAL: Failed to connect to a Gradio client: {e}")
    # Exit with a non-zero status; the bare site `exit()` reported success (0).
    raise SystemExit(1) from e
# --- 2. LOAD & PRECOMPUTE DATA FOR LOCAL SEARCH & INTENT ---
# Load local image data: pair each description in image.json with the first
# file in images/ whose name contains that description (case-insensitive),
# then precompute an embedding per matched description for semantic search.
image_list = []
try:
    with open("image.json", "r", encoding="utf-8") as f:
        data = json.load(f)
except FileNotFoundError:
    print("Warning: image.json not found.")
except json.JSONDecodeError as e:
    # Local search is optional, so a malformed file only disables it.
    print(f"Warning: image.json is not valid JSON ({e}); local image search is disabled.")
else:
    image_dir = Path("images")
    # is_dir() rather than exists(): a plain file named "images" would make iterdir() raise.
    if image_dir.is_dir():
        # List the directory once instead of rescanning it for every description.
        image_files = [p for p in image_dir.iterdir() if p.is_file()]
        for page in data.get("pages", []):
            for img in page.get("images", []):
                description = img.get("description", "")
                if not description:
                    continue
                needle = description.lower()
                matched_file = next((p for p in image_files if needle in p.name.lower()), None)
                if matched_file is not None:
                    image_list.append({"file": str(matched_file), "description": description})
        print(f"Found {len(image_list)} local images for semantic search.")
        print("Precomputing embeddings for local images...")
        if image_list:
            # One batched encode call; each row corresponds to one image_list entry.
            description_embeddings = embedding_model.encode([img["description"] for img in image_list])
            for img, embedding in zip(image_list, description_embeddings):
                img["embedding"] = embedding
        print("Embeddings precomputed.")
    else:
        print("Warning: 'images' directory not found.")
# Load intents from JSON for the rule-based classifier: a mapping of
# intent name -> list of trigger phrases. The app cannot route messages without it.
try:
    with open("intents.json", "r", encoding="utf-8") as f:
        intents_data = json.load(f)
except FileNotFoundError:
    print("FATAL: intents.json not found. This file is required.")
    # Non-zero status: the bare site `exit()` reported success (0).
    raise SystemExit(1)
except json.JSONDecodeError as e:
    print(f"FATAL: intents.json is not valid JSON: {e}")
    raise SystemExit(1) from e
if not isinstance(intents_data, dict):
    print("FATAL: intents.json must map intent names to lists of trigger phrases.")
    raise SystemExit(1)
print("Local intent classifier phrases loaded successfully.")
# --- 3. HELPER FUNCTIONS ---
def get_user_intent_local(user_query: str, intents: Optional[Dict[str, List[str]]] = None) -> dict:
    """Classify ``user_query`` by searching it for intent trigger phrases.

    Matching is case-insensitive and only on whole words/phrases, so a trigger
    such as "hi" does not fire inside "this". When several phrases match, the
    longest one wins (ties go to the first in iteration order), so a specific
    phrase is not shadowed by a shorter, generic one from an earlier intent.
    Blank phrases are ignored, since an empty phrase would match every query.

    Args:
        user_query: Raw text typed by the user.
        intents: Mapping of intent name -> list of trigger phrases. Defaults to
            the module-level ``intents_data`` loaded from intents.json.

    Returns:
        ``{"intent": name, "query": subject}``. ``subject`` is the query with the
        matched phrase removed, original casing kept and whitespace collapsed.
        If nothing remains, it is the full query. When no phrase matches, the
        result is ``{"intent": "chat", "query": user_query}``.
    """
    if intents is None:
        intents = intents_data

    best_intent = None
    best_span = None
    best_length = 0
    for intent, phrases in intents.items():
        for phrase in phrases:
            if not isinstance(phrase, str) or not phrase.strip():
                continue
            phrase = phrase.strip()
            # Lookarounds instead of \b so phrases that begin or end with
            # punctuation are still anchored correctly.
            pattern = r"(?<!\w)" + re.escape(phrase) + r"(?!\w)"
            match = re.search(pattern, user_query, flags=re.IGNORECASE)
            if match and len(phrase) > best_length:
                best_intent, best_span, best_length = intent, match.span(), len(phrase)

    if best_intent is None:
        # No specific trigger phrase found: default to general chat.
        result = {"intent": "chat", "query": user_query}
    else:
        start, end = best_span
        # Remove only the matched occurrence and tidy the leftover whitespace.
        subject = " ".join((user_query[:start] + " " + user_query[end:]).split())
        result = {"intent": best_intent, "query": subject if subject else user_query}
    print(f"Local Intent Classifier Result: {result}")
    return result
def summarize_analysis_with_groq(
    json_result: dict,
    context: str,
    model: str = "llama-3.3-70b-versatile",
    timeout: float = 30.0,
) -> str:
    """Ask Groq to turn a raw model result into a short, user-friendly summary.

    Args:
        json_result: Output of a specialized model, such as the vision Space.
            Values that ``json`` cannot encode are rendered with ``str()``
            instead of raising.
        context: Short description of what the user asked, quoted in the prompt.
        model: Groq chat-completions model id.
        timeout: Seconds to wait for the Groq API before giving up.

    Returns:
        The generated summary. If the request or response parsing fails, a
        fallback message that embeds the raw result.
    """
    # Serialize before the try, but with default=str, so an unexpected value type
    # in the result cannot raise here and discard an otherwise successful analysis.
    pretty_result = json.dumps(json_result, indent=2, default=str)
    prompt = f"""
You are a helpful assistant. Based on the following technical analysis from a specialized AI model, provide a friendly and concise summary for the user.
Context: The user asked to '{context}'.
AI Model's Raw JSON Output:
```json
{pretty_result}
```
Your friendly, easy-to-understand summary:
"""
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json={"messages": [{"role": "user", "content": prompt}], "model": model},
            # Without a timeout, requests can wait forever on a stalled connection.
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Best-effort: the analysis itself succeeded, so show the raw data instead.
        print(f"Groq summary error: {e}")
        return f"I finished the analysis, but had trouble summarizing it. Here is the raw data:\n`{json.dumps(json_result, default=str)}`"
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between ``vec1`` and ``vec2``.

    Returns 0.0 when either vector has zero norm, including an empty vector,
    so the division below never divides by zero.
    """
    magnitude_a = np.linalg.norm(vec1)
    magnitude_b = np.linalg.norm(vec2)
    # A zero-length vector has no direction; treat it as "no similarity".
    if not magnitude_a or not magnitude_b:
        return 0.0
    return np.dot(vec1, vec2) / (magnitude_a * magnitude_b)
def find_best_matching_image(query: str, threshold: float = 0.4) -> Optional[dict]:
    """Return the local image whose description is most similar to ``query``.

    Args:
        query: Free-text search subject.
        threshold: A match is returned only if its cosine similarity is strictly
            greater than this value. The default 0.4 is the former hard-coded cut-off.

    Returns:
        The best ``image_list`` entry (a dict with ``"file"`` and ``"description"``),
        or ``None`` if no images are loaded or the best score does not exceed
        ``threshold``.
    """
    if not image_list:
        return None
    query_emb = embedding_model.encode(query)
    best_match = None
    highest_score = None
    # Single pass: score each image once and keep the first one with the top score.
    for img in image_list:
        # An entry without a precomputed embedding scores 0.0 (empty vector).
        score = cosine_similarity(query_emb, img.get("embedding", []))
        if highest_score is None or score > highest_score:
            best_match, highest_score = img, score
    if highest_score is not None and highest_score > threshold:
        return best_match
    return None
def generate_groq_narrative(
    user_query: str,
    search_result: Optional[dict],
    model: str = "llama-3.3-70b-versatile",
    timeout: float = 30.0,
) -> str:
    """Ask Groq for a short, friendly reply describing an image-search outcome.

    Args:
        user_query: The search subject the user asked for.
        search_result: The matched image entry (must contain ``"description"``),
            or ``None`` when nothing matched.
        model: Groq chat-completions model id. Defaults to the model used by
            ``summarize_analysis_with_groq``. NOTE(review): the previous id
            ``llama3-8b-8192`` appears to be deprecated by Groq; confirm against
            Groq's model list.
        timeout: Seconds to wait for the Groq API before giving up.

    Returns:
        The generated reply, or a fixed apology string if the request fails.
    """
    if search_result:
        prompt = f"A user asked to find: '{user_query}'. You found an image described as: '{search_result['description']}'. Craft a short, friendly response."
    else:
        prompt = f"A user asked to find: '{user_query}'. You searched but couldn't find a match. Craft a short, polite response."
    try:
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={"Authorization": f"Bearer {GROQ_API_KEY}"},
            json={"messages": [{"role": "user", "content": prompt}], "model": model},
            # Without a timeout, requests can wait forever on a stalled connection.
            timeout=timeout,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        # Best-effort: the caller still shows the found image even if this fails.
        print(f"Groq narrative error: {e}")
        return "I had an issue describing the search result."
# --- 4. CORE GRADIO LOGIC ---
def handle_image_analysis(file_path: str) -> str:
    """Send an uploaded image to the vision Space and return a readable summary.

    The raw prediction is passed to ``summarize_analysis_with_groq``. Any
    exception from either step is returned to the user as an apology string
    instead of being raised.
    """
    try:
        raw_prediction = vision_client.predict(image=handle_file(file_path), api_name="/predict")
        return summarize_analysis_with_groq(raw_prediction, "Analyze this image")
    except Exception as err:
        return f"Sorry, I couldn't analyze the image. Error: {err}"
def handle_audio_analysis(file_path: str) -> str:
    """Send an uploaded audio file to the audio Space and report its prediction.

    The Space's response is unpacked into two values; only the first is shown,
    in bold. Any exception, including an unexpected response shape, is
    returned as an apology string.
    """
    try:
        label_text, _unused = audio_client.predict(audio_filepath=handle_file(file_path), api_name="/predict")
    except Exception as err:
        return f"Sorry, I couldn't analyze the audio. Error: {err}"
    return f"The audio analysis result is: **{label_text}**"
def _analyze_uploaded_file(file_path: str) -> str:
    """Route one uploaded file to the image or audio analyzer and return the bot reply."""
    lowered = file_path.lower()
    # The textbox accepts any "image"/"audio" upload, so also consult the MIME type
    # for formats the extension lists miss (e.g. .gif, .ogg, .m4a).
    mime_type = mimetypes.guess_type(file_path)[0] or ""
    if lowered.endswith(('.png', '.jpg', '.jpeg', '.webp')) or mime_type.startswith("image/"):
        return handle_image_analysis(file_path)
    if lowered.endswith(('.wav', '.mp3', '.flac')) or mime_type.startswith("audio/"):
        return handle_audio_analysis(file_path)
    return "I'm not sure how to handle that file type."


def chat_interface(user_input: dict, history: List[Tuple[str, str]]):
    """Handle one submission from the multimodal textbox.

    Uploaded files take priority: each file is analyzed and shown with its
    result, and any accompanying text is displayed (not classified). Otherwise
    the text is classified by ``get_user_intent_local`` and dispatched by intent.

    Args:
        user_input: Textbox value with a ``"text"`` string and a ``"files"`` list
            of file paths.
        history: Chat history as ``(user, bot)`` pairs. ``None`` or empty starts
            a new history. A non-empty list is extended in place and returned.

    Returns:
        ``(history, None)``: the updated history for the Chatbot, and ``None``
        as the new textbox value.
    """
    user_text = (user_input.get("text") or "").strip()
    user_files = user_input.get("files") or []
    new_history = history or []

    # Priority 1: handle file uploads.
    if user_files:
        # Keep accompanying text visible instead of silently dropping it.
        if user_text:
            new_history.append((user_text, None))
        for file_path in user_files:
            # A (path,) tuple displays the uploaded file itself in the chat.
            new_history.append(((file_path,), _analyze_uploaded_file(file_path)))
        return new_history, None

    # Priority 2: handle text-only queries.
    if not user_text:
        return new_history, None

    intent_data = get_user_intent_local(user_text)
    intent = intent_data.get("intent")
    query_subject = intent_data.get("query")

    found_image = None
    if intent == "chat":
        try:
            prediction = chatbot_client.predict(user_input=query_subject, api_name="/chatbot_response")
            # Assumes the Space returns a list of message dicts; the reply is the last one.
            bot_message = prediction[-1]['content']
        except Exception as e:
            print(f"Error calling chatbot client: {e}")
            bot_message = "I'm sorry, I'm having trouble connecting to my chat brain right now. Please try again."
    elif intent == "search_local_image":
        found_image = find_best_matching_image(query_subject)
        bot_message = generate_groq_narrative(query_subject, found_image)
    elif intent == "request_image_analysis":
        bot_message = "Of course. Please upload the image you want me to analyze."
    elif intent == "request_audio_analysis":
        bot_message = "I'm ready. Please upload the audio file for analysis."
    else:
        bot_message = "I'm not sure how to handle that. Can you rephrase?"

    new_history.append((user_text, bot_message))
    if found_image:
        # Display the matched image as its own bot-side row.
        new_history.append((None, (found_image['file'],)))
    return new_history, None
# --- 5. GRADIO UI DEFINITION ---
# Build the Gradio page: a header, the chat history, and a multimodal input box
# whose submit event is handled by chat_interface.
with gr.Blocks(theme=gr.themes.Soft(), title="Multi-Modal AI Chatbot") as demo:
    gr.Markdown("# Multi-Modal AI Chatbot")
    gr.Markdown("I can chat, search for local images, or analyze images and audio you upload.")
    # Created with render=False and placed explicitly by the .render() calls below.
    # chat_interface writes (user, bot) tuple pairs into this history.
    chatbot_history = gr.Chatbot(height=600, show_copy_button=True, layout="bubble", render=False)
    # NOTE(review): the textbox is created inside this Row with render=False but
    # rendered later outside it, so the Row appears to end up empty. Confirm
    # whether the Row is still needed.
    with gr.Row():
        multimodal_textbox = gr.MultimodalTextbox(
            file_types=["image", "audio"],
            placeholder="Type your message or upload a file...",
            submit_btn="Send",
            render=False,
            autofocus=True
        )
    # Render order sets on-page order: chat history first, then the input box.
    chatbot_history.render()
    multimodal_textbox.render()
    # chat_interface returns (history, None). The None is sent to the textbox,
    # presumably clearing it; confirm for the installed Gradio version.
    multimodal_textbox.submit(
        fn=chat_interface,
        inputs=[multimodal_textbox, chatbot_history],
        outputs=[chatbot_history, multimodal_textbox]
    )
# Launch the UI only when this file is run as a script (debug=True is passed to Gradio).
if __name__ == "__main__":
    demo.launch(debug=True)