"""Ginnie: a Gradio RAG assistant for ABB product information.

PDF catalogs (from S3 or a URL) are split into text chunks, embedded with
OpenAI embeddings and stored in Astra DB (Data API collections) together with
extracted images. Chat queries retrieve similar chunks and are answered by
OpenAI, falling back to Mistral.
"""
import base64
import html
import io
import json
import os
import re
import time
import uuid
from urllib.parse import unquote_plus

import boto3
import gradio as gr
import numpy as np
import pandas as pd
import pdfplumber
import plotly.express as px
import plotly.graph_objects as go
import PyPDF2
import requests
from astrapy.db import AstraDB as DataAPIClient
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Cassandra
from langchain_openai import OpenAIEmbeddings
from PIL import Image, ImageDraw, ImageFont

# NOTE(review): json, np, pd, px, go, Cluster, PlainTextAuthProvider,
# SimpleStatement, Cassandra, Image, ImageDraw and ImageFont are not referenced
# in this module at present; they are kept so existing deployments/extensions
# that rely on them keep working.

# Load environment variables
load_dotenv()

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Product categories recognised in queries and catalog names, in priority order
# (the first one found in a query wins).
PRODUCT_TYPES = ("circuit breaker", "motor starter", "contactor", "switch", "relay")
# Same categories as they appear in catalog file names / URLs.
CATALOG_PRODUCT_SLUGS = tuple(p.replace(" ", "_") for p in PRODUCT_TYPES)

SYSTEM_PROMPT = "You are an assistant specialized in ABB products and solutions."
ERROR_MESSAGE = "Sorry, I encountered an error processing your request. Please try again."

OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"
MISTRAL_CHAT_URL = "https://api.mistral.ai/v1/chat/completions"
DEFAULT_ASTRA_DB_ENDPOINT = "https://8e3fd85c-5f28-4e1f-8538-9dd28a3ea2b0-us-east-2.apps.astra.datastax.com"
ABB_LOGO_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/ABB_logo.svg/2560px-ABB_logo.svg.png"

EMBEDDING_MODEL = "text-embedding-ada-002"
EMBEDDING_DIMENSION = 1536  # output size of text-embedding-ada-002
MIN_IMAGE_BYTES = 5000  # images smaller than this are treated as icons/decoration and skipped
HTTP_TIMEOUT_SECONDS = 60  # avoid hanging the UI forever on a stalled request

# ---------------------------------------------------------------------------
# Global state (chat history, analytics, service handles)
# ---------------------------------------------------------------------------

messages = []
product_images = []
current_product = ""
query_counts = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0, "other": 0}
daily_queries = [0, 0, 0, 0, 0, 6, 8, 10, 7, 9, 12, 15, 11, 14]  # Mock data for chart

# Populated by setup_and_update(); default to None/empty so helpers degrade
# gracefully (instead of raising NameError) if they run before setup.
astra_session = None  # astrapy Data API client (not a CQL session)
astra_keyspace = None
astra_collections = {}
s3_client = None
embeddings_model = None


# ---------------------------------------------------------------------------
# Service initialisation
# ---------------------------------------------------------------------------

def _init_api_key(env_var, provider):
    """Return True when `env_var` holds a non-empty API key, logging the outcome.

    The key is expected to be provided via the environment (e.g. Hugging Face Secrets).
    """
    if not os.getenv(env_var):
        print(f"{env_var} is not set in environment variables")
        return False
    print(f"{provider} API initialized with API key from Hugging Face Secrets")
    return True


def init_openai_api():
    """Check that the OpenAI API key is configured. Returns True on success."""
    return _init_api_key("OPENAI_API_KEY", "OpenAI")


def init_mistral_api():
    """Check that the Mistral API key is configured. Returns True on success."""
    return _init_api_key("MISTRAL_API_KEY", "Mistral")


def _open_collection(db, name, **create_options):
    """Create collection `name` with `create_options`, or fall back to a handle on it.

    Collection creation is idempotent when the options match an existing
    collection; if creation fails (e.g. the collection exists with other
    options) a plain handle to the existing collection is returned instead.
    """
    try:
        return db.create_collection(name, **create_options)
    except Exception as create_error:
        print(f"Could not create collection {name} ({create_error}); using existing collection")
        return db.collection(name)


def init_astra_db():
    """Connect to Astra DB through the Data API and open the app's collections.

    Reads ASTRA_DB_APPLICATION_TOKEN, ASTRA_DB_ENDPOINT and ASTRA_DB_KEYSPACE.

    Returns:
        A dict (always, even on failure) with keys ``db`` (client or None),
        ``keyspace`` (str or None) and ``collections`` (name -> handle or None).
    """
    db = None
    astra_db_keyspace = os.getenv("ASTRA_DB_KEYSPACE")
    collections = {"product_embeddings": None, "query_analytics": None, "product_images": None}
    try:
        token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
        endpoint = os.getenv("ASTRA_DB_ENDPOINT", DEFAULT_ASTRA_DB_ENDPOINT)
        if not token:
            print("ASTRA_DB_APPLICATION_TOKEN is not set; skipping Astra DB connection")
        else:
            client_kwargs = {"api_endpoint": endpoint, "token": token}
            if astra_db_keyspace:
                client_kwargs["namespace"] = astra_db_keyspace
            db = DataAPIClient(**client_kwargs)
            try:
                # Vector collection so vector_find() can rank chunks server-side.
                collections["product_embeddings"] = _open_collection(
                    db, "product_embeddings", dimension=EMBEDDING_DIMENSION, metric="cosine"
                )
                collections["query_analytics"] = _open_collection(db, "query_analytics")
                # NOTE(review): base64 image payloads exceed the Data API's limit on
                # indexed string length, so image_data is excluded from indexing.
                collections["product_images"] = _open_collection(
                    db, "product_images", options={"indexing": {"deny": ["image_data"]}}
                )
                print("Successfully created/accessed collections")
            except Exception as collection_error:
                print(f"Error creating collections: {collection_error}")
            print("Connected to Astra DB")
    except Exception as e:
        print(f"Error connecting to Astra DB: {e}")
        db = None

    return {"db": db, "keyspace": astra_db_keyspace, "collections": collections}


def init_s3_client():
    """Create the S3 client used to read product catalogs; None on failure."""
    try:
        return boto3.client(
            "s3",
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
            region_name=os.getenv("AWS_REGION"),
        )
    except Exception as e:
        print(f"Error initializing S3 client: {e}")
        return None


def get_embeddings_model():
    """Create the OpenAI embeddings model used for vector generation; None on failure."""
    try:
        return OpenAIEmbeddings(model=EMBEDDING_MODEL, openai_api_key=os.getenv("OPENAI_API_KEY"))
    except Exception as e:
        print(f"Error initializing embeddings model: {e}")
        return None


def _get_collection(name):
    """Return the named Astra DB collection handle, or None if the DB is unavailable."""
    if not astra_session:
        return None
    return astra_collections.get(name)


# ---------------------------------------------------------------------------
# PDF ingestion
# ---------------------------------------------------------------------------

def _product_type_from_name(name):
    """Infer a product category from a file name or URL; "other" if none matches.

    URL-encoding ("+", "%20") and spaces/hyphens are normalised to underscores
    so names like "Circuit+Breakers.pdf" match the "circuit_breaker" slug.
    """
    normalized = re.sub(r"[\s\-]+", "_", unquote_plus(name).lower())
    for slug in CATALOG_PRODUCT_SLUGS:
        if slug in normalized:
            return slug.replace("_", " ")
    return "other"


def _extract_pdf_text(pdf_content):
    """Return the text of every page of a PDF, pages separated by blank lines."""
    reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
    # extract_text() can return None for image-only pages in some PyPDF2 versions.
    return "".join((page.extract_text() or "") + "\n\n" for page in reader.pages)


def _split_into_chunks(text):
    """Split text into ~1000-character overlapping chunks for embedding."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
    return splitter.split_text(text)


def _ingest_pdf(pdf_content, product_type):
    """Store a PDF's text chunks (with embeddings) and images.

    Returns:
        (number_of_chunks, number_of_images_stored)
    """
    chunks = _split_into_chunks(_extract_pdf_text(pdf_content))
    store_chunks_in_db(chunks, product_type)
    images_count = extract_images_from_pdf(pdf_content, product_type)
    return len(chunks), images_count


def extract_images_from_pdf(pdf_content, product_type):
    """Extract images from a PDF with pdfplumber and store them in Astra DB.

    Images smaller than MIN_IMAGE_BYTES are skipped. Image bytes are stored
    base64-encoded because Data API documents are JSON.

    Returns:
        The number of images stored (0 if the DB is unavailable).
    """
    collection = _get_collection("product_images")
    if collection is None:
        return 0

    images_stored = 0
    try:
        with pdfplumber.open(io.BytesIO(pdf_content)) as pdf:
            for page_num, page in enumerate(pdf.pages):
                for img_index, img in enumerate(page.images):
                    image_bytes = img["stream"].get_data()
                    if len(image_bytes) < MIN_IMAGE_BYTES:
                        continue
                    collection.insert_one({
                        "_id": str(uuid.uuid4()),
                        "product_type": product_type,
                        "image_data": base64.b64encode(image_bytes).decode("ascii"),
                        "page_number": page_num,
                        "image_index": img_index,
                        "metadata": {
                            "product_type": product_type,
                            "page_number": page_num,
                            "image_index": img_index,
                            "timestamp": time.time(),
                            "image_size": len(image_bytes),
                            # NOTE(review): not every PDF image stream is a JPEG;
                            # "jpg" is a simplifying default.
                            "mime_type": "jpg",
                        },
                    })
                    images_stored += 1
    except Exception as e:
        print(f"Error extracting images from PDF: {e}")
    return images_stored


def process_pdf_catalogs(prefix="catalogs/"):
    """Download every PDF under `prefix` in the S3_BUCKET_NAME bucket and ingest it.

    A failure on one file is logged and does not stop the remaining files.

    Returns:
        ``{"status": "success", "files_processed", "chunks_processed", "images_processed"}``
        or ``{"status": "error", "message"}``.
    """
    if not s3_client:
        print("S3 client not initialized, skipping PDF processing")
        return {"status": "error", "message": "S3 client not initialized"}

    bucket_name = os.getenv("S3_BUCKET_NAME")
    if not bucket_name:
        return {"status": "error", "message": "S3_BUCKET_NAME is not set"}

    try:
        # A single list_objects_v2 call returns at most 1000 keys; paginate.
        paginator = s3_client.get_paginator("list_objects_v2")
        pdf_files = [
            obj["Key"]
            for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix)
            for obj in page.get("Contents", [])
            if obj["Key"].lower().endswith(".pdf")
        ]

        files_processed = 0
        processed_chunks = 0
        processed_images = 0
        for pdf_file in pdf_files:
            try:
                pdf_content = s3_client.get_object(Bucket=bucket_name, Key=pdf_file)["Body"].read()
                chunks_count, images_count = _ingest_pdf(pdf_content, _product_type_from_name(pdf_file))
            except Exception as file_error:
                print(f"Error processing {pdf_file}: {file_error}")
                continue
            files_processed += 1
            processed_chunks += chunks_count
            processed_images += images_count
            print(f"Processed {pdf_file}: {chunks_count} text chunks and {images_count} images extracted")

        print(f"PDF processing complete: {files_processed} files, {processed_chunks} chunks, {processed_images} images")
        return {
            "status": "success",
            "files_processed": files_processed,
            "chunks_processed": processed_chunks,
            "images_processed": processed_images,
        }
    except Exception as e:
        print(f"Error processing PDF catalogs: {e}")
        return {"status": "error", "message": str(e)}


def process_pdf_from_url(url):
    """Download a PDF from `url` and ingest it; returns a human-readable status string."""
    if not url or not url.strip():
        return "Error: please provide a PDF URL"
    url = url.strip()
    try:
        response = requests.get(url, timeout=HTTP_TIMEOUT_SECONDS)
        if response.status_code != 200:
            return f"Error downloading PDF: HTTP status code {response.status_code}"

        chunks_count, images_count = _ingest_pdf(response.content, _product_type_from_name(url))
        print(f"Processed PDF from URL: {url}: {chunks_count} text chunks and {images_count} images extracted")
        return f"Successfully processed PDF from URL: {chunks_count} chunks, {images_count} images"
    except Exception as e:
        print(f"Error processing PDF from URL: {e}")
        return f"Error processing PDF: {str(e)}"


# ---------------------------------------------------------------------------
# Vector store and analytics
# ---------------------------------------------------------------------------

def store_chunks_in_db(chunks, product_type):
    """Embed text chunks and store them in the product_embeddings collection.

    Does nothing if the DB or the embeddings model is unavailable, or if
    `chunks` is empty.
    """
    collection = _get_collection("product_embeddings")
    if collection is None or not embeddings_model or not chunks:
        return
    try:
        # One batched embedding request instead of one request per chunk.
        vectors = embeddings_model.embed_documents(list(chunks))
        for chunk, vector in zip(chunks, vectors):
            collection.insert_one({
                "_id": str(uuid.uuid4()),
                "product_type": product_type,
                "content": chunk,
                "metadata": {"product_type": product_type, "timestamp": time.time(), "char_count": len(chunk)},
                "$vector": vector,
            })
    except Exception as e:
        print(f"Error storing chunks in database: {e}")


def search_vector_db(query, product_type=None, limit=5):
    """Return up to `limit` stored chunks most similar to `query`.

    Args:
        query: Free-text query to embed.
        product_type: Optional category to restrict the search to.
        limit: Maximum number of results.

    Returns:
        A list of dicts with ``id``, ``product_type``, ``content`` and
        ``similarity``, most similar first; [] if unavailable or on error.
    """
    collection = _get_collection("product_embeddings")
    if collection is None or not embeddings_model:
        return []
    try:
        query_embedding = embeddings_model.embed_query(query)
        # Structured filter: no string interpolation of user-controlled values.
        search_filter = {"product_type": product_type} if product_type else None
        documents = collection.vector_find(
            query_embedding, limit=limit, filter=search_filter, include_similarity=True
        )
        return [
            {
                "id": doc.get("_id"),
                "product_type": doc.get("product_type"),
                "content": doc.get("content", ""),
                "similarity": doc.get("$similarity"),
            }
            for doc in documents
        ]
    except Exception as e:
        print(f"Error searching vector database: {e}")
        return []


def log_query_analytics(query, product_type, response_time):
    """Record a query, its detected product type and response time in Astra DB."""
    collection = _get_collection("query_analytics")
    if collection is None:
        return
    try:
        collection.insert_one({
            "_id": str(uuid.uuid4()),
            "query": query,
            "product_type": product_type,
            "timestamp": time.time(),
            "response_time": response_time,
        })
    except Exception as e:
        print(f"Error logging query analytics: {e}")


def get_product_images(product):
    """Fetch up to 4 stored images for `product`, write them to ./temp_images, return URLs.

    Returns [] when the DB is unavailable or on error, and two placeholder
    URLs when the DB is reachable but holds no images for the product.
    NOTE(review): the returned "/temp_images/..." paths are not served by
    Gradio as written; confirm how they are meant to be exposed.
    """
    collection = _get_collection("product_images")
    if collection is None:
        return []
    try:
        response = collection.find(filter={"product_type": product}, options={"limit": 4})
        documents = (response or {}).get("data", {}).get("documents", [])

        temp_dir = os.path.join(os.getcwd(), "temp_images")
        os.makedirs(temp_dir, exist_ok=True)

        image_urls = []
        for doc in documents:
            image_data = doc.get("image_data")
            # Keep only filename-safe characters so a stored id cannot escape temp_dir.
            image_id = re.sub(r"[^A-Za-z0-9_-]", "", str(doc.get("_id", "")))
            if not image_id or not image_data:
                continue
            with open(os.path.join(temp_dir, f"image-{image_id}.jpg"), "wb") as f:
                f.write(base64.b64decode(image_data))
            image_urls.append(f"/temp_images/image-{image_id}.jpg")

        if not image_urls:
            slug = product.lower().replace(" ", "-")
            image_urls = [
                f"https://placeholder.com/abb-{slug}-1",
                f"https://placeholder.com/abb-{slug}-2",
            ]
        return image_urls
    except Exception as e:
        print(f"Error retrieving product images: {e}")
        return []


# ---------------------------------------------------------------------------
# LLM calls
# ---------------------------------------------------------------------------

def _detect_product_type(query):
    """Return the first PRODUCT_TYPES entry contained in `query`, else "other"."""
    lowered = query.lower()
    return next((product for product in PRODUCT_TYPES if product in lowered), "other")


def _build_prompt(query, context_chunks):
    """Build the RAG user prompt from the query and retrieved context chunks."""
    context_text = "\n\n".join(chunk["content"] for chunk in context_chunks) if context_chunks else ""
    return f"""
You are an assistant specialized in ABB products and solutions.
Answer the following query about ABB products with accurate and helpful information.

Use the following product information to inform your response:
{context_text}

If the information above doesn't contain relevant details, use your general knowledge about industrial electrical equipment, but be clear about what information comes from the ABB catalog versus general knowledge.

User query: {query}
"""


def _request_chat_completion(url, api_key, model, prompt, provider):
    """POST an OpenAI-style chat completion request.

    Returns:
        The assistant message text, or None if the HTTP status is not 200.
        Network and JSON-shape errors propagate to the caller.
    """
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.7,
        "max_tokens": 800,
    }
    response = requests.post(url, headers=headers, json=payload, timeout=HTTP_TIMEOUT_SECONDS)
    if response.status_code != 200:
        print(f"{provider} API error: {response.status_code}, {response.text}")
        return None
    return response.json()["choices"][0]["message"]["content"]


def _record_query(query, detected_product, start_time):
    """Update in-memory query counts and log the query to Astra DB analytics."""
    counter_key = detected_product if detected_product in query_counts else "other"
    query_counts[counter_key] += 1
    log_query_analytics(query, detected_product, time.time() - start_time)


def get_openai_response(query, context_chunks=None):
    """Answer `query` with OpenAI using RAG, falling back to Mistral on any failure.

    Returns:
        (response_text, detected_product)
    """
    start_time = time.time()
    detected_product = _detect_product_type(query)
    try:
        if not context_chunks:
            context_chunks = search_vector_db(
                query, product_type=detected_product if detected_product != "other" else None
            )
        prompt = _build_prompt(query, context_chunks)
        response_text = _request_chat_completion(
            OPENAI_CHAT_URL, os.getenv("OPENAI_API_KEY"), "gpt-4o", prompt, "OpenAI"
        )
    except Exception as e:
        print(f"Error processing chat request with OpenAI: {e}")
        response_text = None

    if response_text is None:
        # Mistral returns the same (text, product) tuple and records its own analytics,
        # so the query is counted exactly once.
        return get_mistral_response(query, context_chunks)

    _record_query(query, detected_product, start_time)
    return response_text, detected_product


def get_mistral_response(query, context_chunks=None):
    """Answer `query` with Mistral using RAG (fallback model).

    Returns:
        (response_text, detected_product); (ERROR_MESSAGE, "other") on an exception.
    """
    start_time = time.time()
    try:
        detected_product = _detect_product_type(query)
        if not context_chunks:
            context_chunks = search_vector_db(
                query, product_type=detected_product if detected_product != "other" else None
            )
        prompt = _build_prompt(query, context_chunks)
        response_text = _request_chat_completion(
            MISTRAL_CHAT_URL, os.getenv("MISTRAL_API_KEY"), "mistral-large-latest", prompt, "Mistral"
        )
        if response_text is None:
            response_text = ERROR_MESSAGE

        _record_query(query, detected_product, start_time)
        return response_text, detected_product
    except Exception as e:
        print(f"Error processing chat request with Mistral: {e}")
        return ERROR_MESSAGE, "other"


# ---------------------------------------------------------------------------
# UI callbacks
# ---------------------------------------------------------------------------

def process_message(query, history):
    """Answer `query` and return the chat history with the new (query, answer) pair appended."""
    global messages, product_images, current_product

    if not query or not query.strip():
        return history

    context_chunks = search_vector_db(query)

    try:
        response_text, detected_product = get_openai_response(query, context_chunks)
    except Exception as e:
        print(f"Error with OpenAI, falling back to Mistral: {e}")
        response_text, detected_product = get_mistral_response(query, context_chunks)

    new_history = list(history or [])
    new_history.append((query, response_text))

    if detected_product != "other":
        current_product = detected_product
        product_images = get_product_images(detected_product)
    else:
        product_images = []

    # Update daily query data for analytics (in a real app, this would be in a database)
    daily_queries[-1] += 1

    return new_history


def reset_chat(history):
    """Return an empty chat history."""
    return []


def process_pdfs_from_s3(bucket_name, prefix):
    """Ingest PDFs under `prefix` in S3 bucket `bucket_name`; returns a status string."""
    if not bucket_name or not bucket_name.strip():
        return "Error: S3 bucket name is required"
    os.environ["S3_BUCKET_NAME"] = bucket_name.strip()

    result = process_pdf_catalogs(prefix=prefix if prefix is not None else "catalogs/")

    if result["status"] == "success":
        return (
            f"Successfully processed {result['files_processed']} files, "
            f"{result['chunks_processed']} chunks, and {result['images_processed']} images."
        )
    return f"Error: {result['message']}"


def render_images():
    """Render the current product images as an HTML gallery ("" when there are none)."""
    if not product_images:
        return ""
    items = "".join(
        f'<div class="product-image"><img src="{html.escape(url, quote=True)}" '
        f'alt="Product image {i + 1}"></div>'
        for i, url in enumerate(product_images)
    )
    return f'<div class="product-images">{items}</div>'


def setup_and_update():
    """Initialise all external services and return a status message for the UI.

    NOTE(review): wired to app.load, so this re-runs on every page load.
    """
    global astra_session, astra_keyspace, astra_collections, s3_client, embeddings_model

    openai_initialized = init_openai_api()
    mistral_initialized = init_mistral_api()

    astra_result = init_astra_db()
    astra_session = astra_result.get("db")
    astra_keyspace = astra_result.get("keyspace")
    astra_collections = astra_result.get("collections") or {}

    s3_client = init_s3_client()
    embeddings_model = get_embeddings_model()

    status_msg = "System is ready. "
    if not openai_initialized:
        status_msg += "OpenAI API not initialized. "
    if not mistral_initialized:
        status_msg += "Mistral API not initialized. "
    if not astra_session:
        status_msg += "Astra DB not connected. "
    if not s3_client:
        status_msg += "S3 client not initialized. "
    return status_msg


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

CUSTOM_CSS = """
:root {
    --primary-color: #FF000C;
    --secondary-color: #212832;
    --background-color: var(--body-background-fill);
    --card-color: var(--block-background-fill);
    --text-color: var(--body-text-color);
    --border-radius: 12px;
    --shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
.app-header {
    background-color: var(--secondary-color);
    padding: 20px;
    border-radius: var(--border-radius);
    margin-bottom: 20px;
    box-shadow: var(--shadow);
    display: flex;
    align-items: center;
    justify-content: space-between;
}
.app-header img { max-width: 120px; }
.app-title { color: white; margin: 0; font-size: 24px; font-weight: 600; }
.status-card, .catalog-card, .chat-card {
    background-color: var(--card-color);
    border-radius: var(--border-radius);
    padding: 15px;
    margin-bottom: 20px;
    box-shadow: var(--shadow);
}
.chat-card { height: 100%; }
.message { padding: 10px 15px; border-radius: 8px; margin-bottom: 10px; max-width: 85%; }
.user-message { background-color: var(--primary-color); color: white; margin-left: auto; }
.bot-message { background-color: #f0f0f0; color: var(--text-color); margin-right: auto; }
.footer { text-align: center; margin-top: 20px; font-size: 12px; color: var(--text-color); }
.action-button {
    background-color: var(--primary-color);
    color: white;
    border: none;
    border-radius: var(--border-radius);
    padding: 8px 16px;
    cursor: pointer;
    transition: all 0.3s ease;
}
.action-button:hover { opacity: 0.9; }
"""

# NOTE(review): the original tip list was lost from the source; these tips are
# limited to what the app demonstrably supports.
QUICK_TIPS_HTML = """
<div class="catalog-card">
  <h3>Quick Tips</h3>
  <ul>
    <li>Ask about ABB circuit breakers, motor starters, contactors, switches or relays.</li>
    <li>Use Admin Settings to load ABB catalog PDFs from S3 or a URL.</li>
  </ul>
</div>
"""

CATALOG_PDF_URLS = [
    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/ABB+Ability%E2%84%A2+System+800xA%C2%AE+6.2.pdf",
    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/Enclosed+Softstarters.pdf",
    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/Ex-Solutions.pdf",
    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/Low_power_UPS_catalogue_EN.pdf",
]


def create_gradio_app():
    """Build the Gradio Blocks UI (chat, quick tips, admin PDF ingestion) and wire events."""
    with gr.Blocks(css=CUSTOM_CSS) as app:
        status_display = gr.Markdown("System is setting up. Please wait...")

        with gr.Column(scale=1):
            # Header
            with gr.Row(elem_classes="app-header"):
                with gr.Column(scale=1):
                    gr.Image(value=ABB_LOGO_URL, width=120, height=120, interactive=False, label="ABB Logo")
                with gr.Column(scale=3):
                    gr.HTML('<h1 class="app-title">Ginnie</h1>')
                    gr.HTML('<p style="color: white; margin: 0;">Your AI assistant for ABB product information</p>')

            with gr.Row():
                # Chat interface (card styling via elem_classes: a pair of separate
                # gr.HTML("<div>") / gr.HTML("</div>") components cannot wrap widgets).
                with gr.Column(scale=3, elem_classes="chat-card"):
                    chatbot = gr.Chatbot(
                        value=[],
                        elem_id="chatbot",
                        height=500,
                        show_copy_button=True,
                        avatar_images=(
                            "https://ui-avatars.com/api/?name=You&background=0D8ABC&color=fff",
                            "https://ui-avatars.com/api/?name=Ginnie&background=FF000C&color=fff",
                        ),
                    )
                    with gr.Row(elem_classes="input-area"):
                        msg = gr.Textbox(
                            placeholder="Ask about ABB products...",
                            label="",
                            lines=2,
                            max_lines=5,
                            show_label=False,
                        )
                        send_btn = gr.Button("Send", elem_classes="primary-button")
                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat", elem_classes="secondary-button")

                with gr.Column(scale=1):
                    gr.HTML(QUICK_TIPS_HTML)

                    with gr.Accordion("Admin Settings", open=False):
                        with gr.Tab("Process PDFs"):
                            s3_bucket = gr.Textbox(label="S3 Bucket Name")
                            s3_prefix = gr.Textbox(label="S3 Prefix (folder)", value="catalogs/")
                            process_btn = gr.Button("Process PDFs from S3", elem_classes="action-button")

                        with gr.Tab("Direct PDF URLs"):
                            pdf_url = gr.Textbox(label="PDF URL", placeholder="https://example.com/sample.pdf")
                            pdf_dropdown = gr.Dropdown(
                                label="ABB Catalog PDFs",
                                choices=CATALOG_PDF_URLS,
                                interactive=True,
                            )
                            process_url_btn = gr.Button("Process PDF from URL", elem_classes="action-button")

                        # Shared by both ingestion buttons.
                        result_text = gr.Textbox(label="Processing Result")

        # Event handlers
        send_btn.click(process_message, [msg, chatbot], [chatbot], api_name="send_message")
        msg.submit(process_message, [msg, chatbot], [chatbot], api_name="send_message_enter")
        clear_btn.click(reset_chat, [chatbot], [chatbot], api_name="clear_chat")
        process_btn.click(process_pdfs_from_s3, [s3_bucket, s3_prefix], [result_text], api_name="process_pdfs")
        process_url_btn.click(process_pdf_from_url, [pdf_url], [result_text], api_name="process_pdf_url")
        # Selecting a catalog copies its URL into the URL textbox.
        pdf_dropdown.change(lambda x: x, [pdf_dropdown], [pdf_url], api_name="update_pdf_url")

        # Initialise services when the page loads and show the resulting status.
        app.load(setup_and_update, None, status_display)

    return app


if __name__ == "__main__":
    demo = create_gradio_app()
    demo.launch(share=True)