Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import pandas as pd | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import boto3 | |
| import PyPDF2 | |
| import io | |
| import uuid | |
| import json | |
| import re | |
| import time | |
| import numpy as np | |
| import pdfplumber | |
| import requests | |
| from dotenv import load_dotenv | |
| from cassandra.cluster import Cluster | |
| from cassandra.auth import PlainTextAuthProvider | |
| from cassandra.query import SimpleStatement | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Cassandra | |
| from langchain_openai import OpenAIEmbeddings | |
| from PIL import Image, ImageDraw, ImageFont | |
| from astrapy.db import AstraDB as DataAPIClient | |
# Load environment variables (python-dotenv)
load_dotenv()

# Global variables to store chat history and analytics data.
# These live at module level, so they are shared by every caller in this process.
messages = []
product_images = []  # image URLs for the most recently detected product (set by process_message)
current_product = ""
query_counts = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0, "other": 0}
daily_queries = [0, 0, 0, 0, 0, 6, 8, 10, 7, 9, 12, 15, 11, 14]  # Mock data for chart

# Service handles. setup_and_update() (registered with app.load) overwrites these.
# Defaulting them to None lets the helpers that guard with e.g. `if not astra_session:`
# degrade gracefully instead of raising NameError if they run before setup has happened.
astra_session = None
astra_keyspace = None
s3_client = None
embeddings_model = None
# Initialize OpenAI API
def init_openai_api():
    """Check that an OpenAI API key is available in the environment.

    The key is read from ``OPENAI_API_KEY`` (on Hugging Face Spaces this is
    populated from the Space's secrets) and written back to ``os.environ`` so
    libraries that read the variable directly see the same value.

    Returns:
        bool: True when a non-empty key was found, False otherwise.
    """
    try:
        key = os.getenv("OPENAI_API_KEY")
        if key:
            os.environ["OPENAI_API_KEY"] = key
            print("OpenAI API initialized with API key from Hugging Face Secrets")
            return True
        print("OPENAI_API_KEY is not set in environment variables")
        return False
    except Exception as exc:
        print(f"Error initializing OpenAI API: {exc}")
        return False
# Initialize Mistral API
def init_mistral_api():
    """Verify that a Mistral API key is present in ``MISTRAL_API_KEY``.

    On Hugging Face Spaces the variable comes from the Space's secrets. The key
    is re-exported to ``os.environ`` for libraries that read it directly.

    Returns:
        bool: True if a non-empty key is available, otherwise False.
    """
    try:
        mistral_key = os.getenv("MISTRAL_API_KEY")
        if not mistral_key:
            print("MISTRAL_API_KEY is not set in environment variables")
            return False
        os.environ["MISTRAL_API_KEY"] = mistral_key
        print("Mistral API initialized with API key from Hugging Face Secrets")
        return True
    except Exception as err:
        print(f"Error initializing Mistral API: {err}")
        return False
# Initialize Astra DB connection
def init_astra_db():
    """Connect to Astra DB and open the collections this app uses.

    Configuration is read from the environment:
      * ``ASTRA_DB_APPLICATION_TOKEN`` - required; without it no client is created.
      * ``ASTRA_DB_ENDPOINT`` - Data API endpoint (falls back to a hard-coded endpoint).
      * ``ASTRA_DB_KEYSPACE`` - returned to the caller; not used for the connection.

    Returns:
        dict: Always a dict of the form
        ``{"db": client_or_None, "keyspace": str_or_None,
        "collections": {"product_embeddings": ..., "query_analytics": ...,
        "product_images": ...}}``. Any entry may be None if configuration is
        missing or connecting/opening it failed. Errors are printed, not raised.

    NOTE(review): ``db`` is an astrapy Data API client (``AstraDB``), but callers
    store it as ``astra_session`` and call ``.execute(<CQL>, params)`` on it,
    which is the cassandra-driver Session style - confirm the two actually match.
    """
    db = None
    product_embeddings = None
    query_analytics = None
    product_images = None
    astra_db_keyspace = os.getenv("ASTRA_DB_KEYSPACE")
    astra_db_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
    astra_db_endpoint = os.getenv(
        "ASTRA_DB_ENDPOINT",
        "https://8e3fd85c-5f28-4e1f-8538-9dd28a3ea2b0-us-east-2.apps.astra.datastax.com",
    )

    if not astra_db_application_token:
        # Don't build a client with token=None; leaving db as None lets the
        # caller report Astra DB as not connected.
        print("ASTRA_DB_APPLICATION_TOKEN is not set in environment variables")
    else:
        try:
            db = DataAPIClient(api_endpoint=astra_db_endpoint, token=astra_db_application_token)
            # Collection failures are reported separately so a working client is
            # still returned even if a collection cannot be opened.
            try:
                product_embeddings = db.collection("product_embeddings")
                query_analytics = db.collection("query_analytics")
                product_images = db.collection("product_images")
                print("Successfully created/accessed collections")
            except Exception as collection_error:
                print(f"Error creating collections: {collection_error}")
            print("Connected to Astra DB")
        except Exception as e:
            print(f"Error connecting to Astra DB: {e}")
            db = None

    return {
        "db": db,
        "keyspace": astra_db_keyspace,
        "collections": {
            "product_embeddings": product_embeddings,
            "query_analytics": query_analytics,
            "product_images": product_images,
        },
    }
# Initialize AWS S3 client for accessing product catalogs
def init_s3_client():
    """Create a boto3 S3 client for reading product catalogs.

    Credentials and region come from AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
    and AWS_REGION; unset variables are passed through as None.

    Returns:
        The S3 client, or None if constructing it raised.
    """
    client_kwargs = {
        "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        "region_name": os.getenv("AWS_REGION"),
    }
    try:
        return boto3.client('s3', **client_kwargs)
    except Exception as exc:
        print(f"Error initializing S3 client: {exc}")
        return None
# Initialize embedding model
def get_embeddings_model():
    """Build the OpenAI embeddings client used for vector generation.

    Uses the ``text-embedding-ada-002`` model with the key from OPENAI_API_KEY.

    Returns:
        An ``OpenAIEmbeddings`` instance, or None if construction raised.
    """
    model_name = "text-embedding-ada-002"
    try:
        model = OpenAIEmbeddings(
            model=model_name,
            openai_api_key=os.getenv("OPENAI_API_KEY"),
        )
    except Exception as exc:
        print(f"Error initializing embeddings model: {exc}")
        model = None
    return model
# Extract images from PDFs and store in Astra DB
def extract_images_from_pdf(pdf_content, product_type):
    """Extract embedded images from a PDF and insert them into ``product_images``.

    Images whose decoded stream is smaller than 5000 bytes are skipped. Each
    stored row gets a fresh UUID plus a JSON metadata blob (0-based page number,
    image index on the page, timestamp, byte size, and a fixed "jpg" mime label).

    Args:
        pdf_content: The PDF file as bytes.
        product_type: Product category label stored alongside each image.

    Returns:
        int: Number of images successfully inserted. This is 0 when no DB
        session is available or the PDF cannot be opened. A failure on a single
        image is logged and skipped; it does not discard images already stored.
    """
    if not astra_session:
        return 0
    images_stored = 0
    try:
        with pdfplumber.open(io.BytesIO(pdf_content)) as pdf:
            for page_num, page in enumerate(pdf.pages):  # page_num is 0-based
                for img_index, img in enumerate(page.images):
                    try:
                        image_bytes = img["stream"].get_data()
                        # Skip small images (icons, bullets, decorations)
                        if len(image_bytes) < 5000:
                            continue
                        image_id = str(uuid.uuid4())
                        metadata = json.dumps({
                            "product_type": product_type,
                            "page_number": page_num,
                            "image_index": img_index,
                            "timestamp": time.time(),
                            "image_size": len(image_bytes),
                            # NOTE(review): the stream is never checked to actually be JPEG data.
                            "mime_type": "jpg"  # Default to jpg for simplicity
                        })
                        astra_session.execute(
                            f"""
INSERT INTO {astra_keyspace}.product_images
(id, product_type, image_data, page_number, image_index, metadata)
VALUES (%s, %s, %s, %s, %s, %s)
""",
                            (image_id, product_type, bytearray(image_bytes), page_num, img_index, metadata)
                        )
                        images_stored += 1
                    except Exception as img_error:
                        # Isolate per-image failures so the count of already-stored
                        # images is still reported.
                        print(f"Error storing image {img_index} on page {page_num}: {img_error}")
    except Exception as e:
        print(f"Error extracting images from PDF: {e}")
    return images_stored
# Function to download and process PDFs from S3
def process_pdf_catalogs(prefix=None):
    """Download every PDF under an S3 prefix, then index its text and images.

    The bucket comes from the ``S3_BUCKET_NAME`` environment variable. The key
    prefix is ``prefix`` if given, otherwise the ``S3_PREFIX`` environment
    variable, otherwise ``"catalogs/"``. Each file's product type is inferred
    from its key: the first of circuit_breaker, motor_starter, contactor,
    switch or relay found in the key, with underscores turned into spaces.
    Keys matching none of these get ``"other"``.

    A failure on one file is logged and that file is skipped; the rest of the
    batch is still processed.

    Args:
        prefix: Optional S3 key prefix to list (backward-compatible addition).

    Returns:
        dict: On success, ``{"status": "success", "files_processed",
        "files_failed", "chunks_processed", "images_processed"}``. Otherwise
        ``{"status": "error", "message": str}``, returned when the batch cannot
        start or every file failed.
    """
    if not s3_client:
        print("S3 client not initialized, skipping PDF processing")
        return {"status": "error", "message": "S3 client not initialized"}

    bucket_name = os.getenv("S3_BUCKET_NAME")
    if not bucket_name:
        return {"status": "error", "message": "S3_BUCKET_NAME is not set"}
    if prefix is None:
        prefix = os.getenv("S3_PREFIX", "catalogs/")

    try:
        # A single list_objects_v2 call returns at most 1000 keys; the paginator
        # follows continuation tokens so large catalogs are not silently truncated.
        paginator = s3_client.get_paginator("list_objects_v2")
        pdf_files = [
            obj["Key"]
            for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix)
            for obj in page.get("Contents", [])
            if obj["Key"].lower().endswith(".pdf")
        ]
    except Exception as e:
        print(f"Error processing PDF catalogs: {e}")
        return {"status": "error", "message": str(e)}

    # Same splitter settings for every file, so build it once.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    files_processed = 0
    files_failed = 0
    processed_chunks = 0
    processed_images = 0

    for pdf_file in pdf_files:
        product_type = "other"
        for pt in ["circuit_breaker", "motor_starter", "contactor", "switch", "relay"]:
            if pt in pdf_file.lower():
                product_type = pt.replace("_", " ")
                break
        try:
            s3_object = s3_client.get_object(Bucket=bucket_name, Key=pdf_file)
            pdf_content = s3_object['Body'].read()
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
            # `or ""` guards against extract_text() returning None for a page.
            text_content = "".join((page.extract_text() or "") + "\n\n" for page in pdf_reader.pages)
            chunks = text_splitter.split_text(text_content)
            store_chunks_in_db(chunks, product_type)
            images_count = extract_images_from_pdf(pdf_content, product_type)
        except Exception as file_error:
            files_failed += 1
            print(f"Error processing {pdf_file}: {file_error}")
            continue
        files_processed += 1
        processed_chunks += len(chunks)
        processed_images += images_count
        print(f"Processed {pdf_file}: {len(chunks)} text chunks and {images_count} images extracted")

    if pdf_files and files_processed == 0:
        return {"status": "error", "message": f"All {files_failed} PDF files failed to process"}

    print(f"PDF processing complete: {files_processed} files, {processed_chunks} chunks, {processed_images} images")
    return {
        "status": "success",
        "files_processed": files_processed,
        "files_failed": files_failed,
        "chunks_processed": processed_chunks,
        "images_processed": processed_images
    }
# Process a single PDF downloaded from a URL
def process_pdf_from_url(url):
    """Download a PDF from ``url`` and index its text chunks and images.

    The product type is inferred from the URL in the same way as for S3
    catalogs, defaulting to ``"other"``. Text is split into 1000-character
    chunks with 200-character overlap. Chunks and images are stored only when a
    DB session is available.

    Args:
        url: HTTP(S) URL of the PDF. Surrounding whitespace is ignored.

    Returns:
        str: A human-readable status message. Download and processing errors
        are reported in this string rather than raised.
    """
    url = (url or "").strip()
    if not url:
        return "Error: please provide a PDF URL."
    try:
        # The context manager releases the streamed connection on every path,
        # including the early return on a non-200 status. The timeout stops a
        # stalled server from hanging the handler indefinitely.
        with requests.get(url, stream=True, timeout=60) as response:
            if response.status_code != 200:
                return f"Error downloading PDF: HTTP status code {response.status_code}"
            pdf_content = response.content

        product_type = "other"
        for pt in ["circuit_breaker", "motor_starter", "contactor", "switch", "relay"]:
            if pt in url.lower():
                product_type = pt.replace("_", " ")
                break

        pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
        # `or ""` guards against extract_text() returning None for a page.
        text_content = "".join((page.extract_text() or "") + "\n\n" for page in pdf_reader.pages)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_text(text_content)

        # Store chunks in vector database (if available)
        if astra_session:
            store_chunks_in_db(chunks, product_type)
        # Extract and store images (if database available)
        images_count = 0
        if astra_session:
            images_count = extract_images_from_pdf(pdf_content, product_type)

        print(f"Processed PDF from URL: {url}: {len(chunks)} text chunks and {images_count} images extracted")
        return f"Successfully processed PDF from URL: {len(chunks)} chunks, {images_count} images"
    except Exception as e:
        print(f"Error processing PDF from URL: {e}")
        return f"Error processing PDF: {str(e)}"
# Function to store text chunks in Astra DB with embeddings
def store_chunks_in_db(chunks, product_type):
    """Embed text chunks and insert them into the ``product_embeddings`` table.

    Each row gets a fresh UUID, the chunk text, its embedding vector and a JSON
    metadata blob (product type, timestamp, character count).

    Args:
        chunks: Iterable of text chunks to store.
        product_type: Product category label stored with each chunk.

    Returns:
        None. Does nothing when the DB session or embeddings model is missing,
        or when ``chunks`` is empty. Errors are printed rather than raised, and
        the first failure stops further inserts.
    """
    if not astra_session or not embeddings_model:
        # Skip if database or embeddings model isn't available
        return
    chunks = list(chunks)
    if not chunks:
        return
    try:
        # One batched embedding request instead of one API round-trip per chunk.
        embedding_vectors = embeddings_model.embed_documents(chunks)
        for chunk, embedding_vector in zip(chunks, embedding_vectors):
            chunk_id = str(uuid.uuid4())
            metadata = json.dumps({
                "product_type": product_type,
                "timestamp": time.time(),
                "char_count": len(chunk)
            })
            astra_session.execute(
                f"""
INSERT INTO {astra_keyspace}.product_embeddings
(id, product_type, content, embedding_vector, metadata)
VALUES (%s, %s, %s, %s, %s)
""",
                (chunk_id, product_type, chunk, embedding_vector, metadata)
            )
    except Exception as e:
        print(f"Error storing chunks in database: {e}")
# Function to search for relevant product information in the vector database
def search_vector_db(query, product_type=None, limit=5):
    """Return the stored chunks most similar to ``query`` by cosine similarity.

    All rows of ``product_embeddings`` (optionally filtered by product type)
    are fetched and ranked in Python.

    Args:
        query: Text to embed and compare against stored chunk embeddings.
        product_type: Optional product category to filter on.
        limit: Maximum number of results to return.

    Returns:
        list[dict]: Up to ``limit`` dicts with ``id``, ``product_type``,
        ``content`` and ``similarity`` (a float), ordered highest similarity
        first. Returns ``[]`` when the DB or embeddings model is unavailable,
        when the query embedding has zero norm, or on error.
    """
    if not astra_session or not embeddings_model:
        # Return empty results if DB isn't available
        return []
    try:
        query_vec = np.asarray(embeddings_model.embed_query(query), dtype=float)
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            return []

        cql_query = f"""
SELECT id, product_type, content, embedding_vector
FROM {astra_keyspace}.product_embeddings
"""
        if product_type:
            # Bind the value instead of interpolating it into the CQL text.
            # NOTE(review): filtering on a non-key column may need ALLOW FILTERING
            # or an index depending on the table schema - confirm.
            rows = astra_session.execute(cql_query + " WHERE product_type = %s", (product_type,))
        else:
            rows = astra_session.execute(cql_query)

        results = []
        for row in rows:
            if row.embedding_vector is None:
                continue
            db_vec = np.asarray(row.embedding_vector, dtype=float)
            db_norm = np.linalg.norm(db_vec)
            if db_norm == 0:
                # Cosine similarity is undefined for a zero vector; skip instead of producing NaN.
                continue
            results.append({
                "id": row.id,
                "product_type": row.product_type,
                "content": row.content,
                "similarity": float(np.dot(query_vec, db_vec) / (query_norm * db_norm))
            })

        results.sort(key=lambda item: item["similarity"], reverse=True)
        return results[:limit]
    except Exception as e:
        print(f"Error searching vector database: {e}")
        return []
def log_query_analytics(query, product_type, response_time):
    """Record one chat query as a row in the ``query_analytics`` table.

    Args:
        query: The user's query text.
        product_type: Product category detected for the query.
        response_time: Elapsed time for the answer (callers in this file pass
            seconds measured with ``time.time()``).

    Returns immediately when no DB session is configured. Insert errors are
    printed, not raised.
    """
    if astra_session:
        try:
            row = (str(uuid.uuid4()), query, product_type, time.time(), response_time)
            astra_session.execute(
                f"""
INSERT INTO {astra_keyspace}.query_analytics
(id, query, product_type, timestamp, response_time)
VALUES (%s, %s, %s, %s, %s)
""",
                row
            )
        except Exception as exc:
            print(f"Error logging query analytics: {exc}")
# Get product images from Astra DB
def get_product_images(product):
    """Fetch up to four stored images for ``product`` and write them to ./temp_images.

    Args:
        product: Product category label, matched against ``product_type``.

    Returns:
        list[str]: One of the following:
        - URL paths of the form ``/temp_images/image-<id>.jpg`` for the images
          that were written.
        - Two placeholder URLs when no image data was found.
        - ``[]`` when no DB session is available or the query fails.

    NOTE(review): nothing here serves ``temp_images/`` over HTTP - confirm the
    web server actually exposes these paths.
    """
    if not astra_session:
        return []
    try:
        query = f"""
SELECT id, product_type, image_data, metadata
FROM {astra_keyspace}.product_images
WHERE product_type = %s
LIMIT 4
"""
        rows = astra_session.execute(query, (product,))

        # Create the output directory once rather than on every row.
        temp_dir = os.path.join(os.getcwd(), 'temp_images')
        os.makedirs(temp_dir, exist_ok=True)

        image_urls = []
        for row in rows:
            if row.image_data is None:
                # Nothing to write; skip it instead of failing the whole lookup.
                continue
            file_name = f"image-{row.id}.jpg"
            with open(os.path.join(temp_dir, file_name), 'wb') as f:
                f.write(row.image_data)
            image_urls.append(f"/temp_images/{file_name}")

        # If no images found, use placeholder URLs
        if not image_urls:
            slug = product.lower().replace(' ', '-')
            image_urls = [
                f"https://placeholder.com/abb-{slug}-1",
                f"https://placeholder.com/abb-{slug}-2"
            ]
        return image_urls
    except Exception as e:
        print(f"Error retrieving product images: {e}")
        return []
# Get response from OpenAI API
def get_openai_response(query, context_chunks=None):
    """Answer ``query`` with OpenAI ``gpt-4o`` using retrieved catalog context (RAG).

    When the OpenAI call fails (non-200 status or any exception), this falls
    back to ``get_mistral_response`` and returns its ``(text, product)`` tuple
    unchanged. Mistral does its own query counting and analytics logging, so
    the query is not counted twice.

    Args:
        query: The user's question.
        context_chunks: Optional pre-retrieved search results, as dicts with a
            ``"content"`` key. When empty or None, the vector DB is searched,
            filtered by the detected product type.

    Returns:
        tuple[str, str]: ``(response_text, detected_product)``.
        ``detected_product`` is the first of "circuit breaker",
        "motor starter", "contactor", "switch" or "relay" that appears in the
        lower-cased query, else ``"other"``.
    """
    start_time = time.time()
    answered = False
    response_text = ""
    detected_product = "other"
    try:
        lowered_query = query.lower()
        detected_product = next(
            (keyword
             for keyword in ("circuit breaker", "motor starter", "contactor", "switch", "relay")
             if keyword in lowered_query),
            "other",
        )
        # If no context chunks provided, search the vector DB
        if not context_chunks:
            context_chunks = search_vector_db(
                query, product_type=detected_product if detected_product != "other" else None
            )
        context_text = "\n\n".join(chunk["content"] for chunk in context_chunks) if context_chunks else ""
        prompt = f"""
You are an assistant specialized in ABB products and solutions. Answer the following query about ABB products with accurate and helpful information.
Use the following product information to inform your response:
{context_text}
If the information above doesn't contain relevant details, use your general knowledge about industrial electrical equipment, but be clear about what information comes from the ABB catalog versus general knowledge.
User query: {query}
"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
        }
        payload = {
            "model": "gpt-4o",
            "messages": [
                {"role": "system", "content": "You are an assistant specialized in ABB products and solutions."},
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.7,
            "max_tokens": 800
        }
        # The timeout keeps a stalled API from blocking the handler forever;
        # a timeout error is caught below and triggers the Mistral fallback.
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60
        )
        if response.status_code == 200:
            response_text = response.json()["choices"][0]["message"]["content"]
            answered = True
        else:
            print(f"OpenAI API error: {response.status_code}, {response.text}")
    except Exception as e:
        print(f"Error processing chat request with OpenAI: {e}")

    if not answered:
        # Return Mistral's (text, product) tuple as-is: unpacking it avoids
        # nesting the tuple, and Mistral already records analytics itself.
        try:
            return get_mistral_response(query, context_chunks)
        except Exception:
            return "Sorry, I encountered an error processing your request. Please try again.", "other"

    # Analytics must never discard a successful answer.
    try:
        bucket = detected_product if detected_product in query_counts else "other"
        query_counts[bucket] += 1
        log_query_analytics(query, detected_product, time.time() - start_time)
    except Exception as e:
        print(f"Error recording query analytics: {e}")
    return response_text, detected_product
# Get response from Mistral API (fallback)
def get_mistral_response(query, context_chunks=None):
    """Answer ``query`` with Mistral ``mistral-large-latest`` using retrieved context (RAG).

    This is used as the fallback for the OpenAI path. It updates the
    module-level ``query_counts`` and logs analytics even when the API returns
    a non-200 status.

    Args:
        query: The user's question.
        context_chunks: Optional pre-retrieved search results, as dicts with a
            ``"content"`` key. When empty or None, the vector DB is searched,
            filtered by the detected product type.

    Returns:
        tuple[str, str]: ``(response_text, detected_product)``.
        ``detected_product`` is the first of "circuit breaker",
        "motor starter", "contactor", "switch" or "relay" that appears in the
        lower-cased query, else ``"other"``. On any exception this returns an
        apology message and ``"other"``.
    """
    start_time = time.time()
    try:
        lowered_query = query.lower()
        detected_product = next(
            (keyword
             for keyword in ("circuit breaker", "motor starter", "contactor", "switch", "relay")
             if keyword in lowered_query),
            "other",
        )
        # If no context chunks provided, search the vector DB
        if not context_chunks:
            context_chunks = search_vector_db(
                query, product_type=detected_product if detected_product != "other" else None
            )
        context_text = "\n\n".join(chunk["content"] for chunk in context_chunks) if context_chunks else ""
        prompt = f"""
You are an assistant specialized in ABB products and solutions. Answer the following query about ABB products with accurate and helpful information.
Use the following product information to inform your response:
{context_text}
If the information above doesn't contain relevant details, use your general knowledge about industrial electrical equipment, but be clear about what information comes from the ABB catalog versus general knowledge.
User query: {query}
"""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.getenv('MISTRAL_API_KEY')}"
        }
        payload = {
            "model": "mistral-large-latest",
            "messages": [
                {"role": "system", "content": "You are an assistant specialized in ABB products and solutions."},
                {"role": "user", "content": prompt}
            ],
            "temperature": 0.7,
            "max_tokens": 800
        }
        # Without a timeout a stalled API would block this handler indefinitely.
        response = requests.post(
            "https://api.mistral.ai/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=60
        )
        if response.status_code == 200:
            response_text = response.json()["choices"][0]["message"]["content"]
        else:
            print(f"Mistral API error: {response.status_code}, {response.text}")
            response_text = "Sorry, I encountered an error processing your request. Please try again."

        # Update query counts for analytics
        bucket = detected_product if detected_product in query_counts else "other"
        query_counts[bucket] += 1
        log_query_analytics(query, detected_product, time.time() - start_time)
        return response_text, detected_product
    except Exception as e:
        print(f"Error processing chat request with Mistral: {e}")
        return "Sorry, I encountered an error processing your request. Please try again.", "other"
def process_message(query, history):
    """Answer one chat message and append the exchange to the chat history.

    The function:
    - retrieves context from the vector DB;
    - asks OpenAI, falling back to Mistral;
    - refreshes the module-level ``product_images`` and ``current_product``
      for the detected product;
    - increments today's counter in ``daily_queries``.

    Args:
        query: The user's message. Blank or None messages are ignored.
        history: List of ``(user, bot)`` pairs from the Chatbot. None is
            treated as empty.

    Returns:
        list: For a blank query, the history unchanged (``[]`` if it was None).
        Otherwise a new list with the exchange appended; the input list is not
        mutated.
    """
    global product_images, current_product
    if not query or not query.strip():
        return history if history is not None else []

    # Get context from vector database
    context_chunks = search_vector_db(query)
    # Get LLM response with RAG (try OpenAI first, fallback to Mistral)
    try:
        response_text, detected_product = get_openai_response(query, context_chunks)
    except Exception as e:
        print(f"Error with OpenAI, falling back to Mistral: {e}")
        response_text, detected_product = get_mistral_response(query, context_chunks)

    new_history = list(history or [])
    new_history.append((query, response_text))

    if detected_product != "other":
        current_product = detected_product
        product_images = get_product_images(detected_product)
    else:
        product_images = []

    # Update daily query data for analytics (in a real app, this would be in a database)
    daily_queries[-1] += 1
    return new_history
def reset_chat(history):
    """Clear the conversation.

    Args:
        history: Current chat history. It is ignored; the parameter exists
            because the Clear button passes the Chatbot value as input.

    Returns:
        list: A new, empty history.
    """
    cleared_history = []
    return cleared_history
def process_pdfs_from_s3(bucket_name, prefix):
    """Process the PDFs in an S3 bucket and summarise the outcome for the UI.

    Args:
        bucket_name: Bucket to read from. Surrounding whitespace is ignored,
            and it is published as ``S3_BUCKET_NAME``.
        prefix: Key prefix ("folder") entered in the UI, published as
            ``S3_PREFIX``.

    Returns:
        str: A success summary, or a message starting with ``"Error: "``.
    """
    bucket_name = (bucket_name or "").strip()
    if not bucket_name:
        return "Error: S3 bucket name is required."

    # process_pdf_catalogs() reads the bucket from the environment.
    os.environ["S3_BUCKET_NAME"] = bucket_name
    # NOTE(review): the prefix only takes effect if process_pdf_catalogs()
    # reads S3_PREFIX - confirm.
    os.environ["S3_PREFIX"] = (prefix or "").strip()

    result = process_pdf_catalogs()
    if result["status"] == "success":
        return f"Successfully processed {result['files_processed']} files, {result['chunks_processed']} chunks, and {result['images_processed']} images."
    return f"Error: {result['message']}"
def render_images(images=None):
    """Render product image entries as an HTML grid of placeholder tiles.

    Each entry is shown as a generic image icon with the entry's text
    (HTML-escaped) as a caption.

    Args:
        images: Optional list of image URLs. It defaults to the module-level
            ``product_images``, which keeps the original zero-argument call
            working.

    Returns:
        str: The HTML markup, or ``""`` when there is nothing to render.
    """
    from html import escape

    if images is None:
        images = product_images
    if not images:
        return ""

    tiles = []
    for url in images:
        # The URL is interpolated into HTML, so escape it to keep '<', '&' or
        # quotes from breaking (or injecting into) the markup.
        tiles.append(f"""
<div style='background: #f3f4f6; border-radius: 6px; padding: 8px; text-align: center;'>
<div style='height: 100px; display: flex; align-items: center; justify-content: center; background: rgba(0,0,0,0.05); border-radius: 4px;'>
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="18" height="18" x="3" y="3" rx="2" ry="2"/><circle cx="9" cy="9" r="2"/><path d="m21 15-3.086-3.086a2 2 0 0 0-2.828 0L6 21"/></svg>
</div>
<p style='margin-top: 4px; font-size: 12px;'>{escape(str(url))}</p>
</div>
""")
    return (
        "<div style='margin-top: 12px; display: grid; grid-template-columns: 1fr 1fr; gap: 8px;'>"
        + "".join(tiles)
        + "</div>"
    )
def setup_and_update():
    """Initialise all external services and return a status line for the UI.

    The function checks the OpenAI and Mistral keys, connects to Astra DB, and
    creates the S3 client and the embeddings model. The resulting handles are
    stored in the module globals ``astra_session``, ``astra_keyspace``,
    ``s3_client`` and ``embeddings_model``.

    Returns:
        str: ``"System is ready. "`` followed by a note for each service that
        failed to initialise.
    """
    global astra_session, astra_keyspace, s3_client, embeddings_model

    openai_initialized = init_openai_api()
    mistral_initialized = init_mistral_api()

    # init_astra_db always returns a dict; `or {}` just guards a None result.
    astra_result = init_astra_db() or {}
    astra_session = astra_result.get("db")
    astra_keyspace = astra_result.get("keyspace")
    s3_client = init_s3_client()
    embeddings_model = get_embeddings_model()

    status_msg = "System is ready. "
    if not openai_initialized:
        status_msg += "OpenAI API not initialized. "
    if not mistral_initialized:
        status_msg += "Mistral API not initialized. "
    if not astra_session:
        status_msg += "Astra DB not connected. "
    if not s3_client:
        status_msg += "S3 client not initialized. "
    if embeddings_model is None:
        # Without embeddings, neither indexing nor vector search can run.
        status_msg += "Embeddings model not initialized. "
    return status_msg
def create_gradio_app():
    """Build the Gradio Blocks UI for the "Ginnie" ABB assistant and wire its events.

    The layout is:
    - a header (logo and title);
    - a chat column (Chatbot, message box, Send and Clear buttons);
    - a side column with quick tips and an "Admin Settings" accordion for
      processing PDFs from S3 or from a URL.

    ``setup_and_update`` is registered to run on page load, and its status text
    is written to ``status_display``.

    Returns:
        gr.Blocks: The constructed (not yet launched) app.
    """
    # NOTE(review): the nesting below was reconstructed from a flattened copy of
    # this file (e.g. where result_text sits); confirm it against the intended layout.
    # Define CSS styles for a more modern, appealing interface.
    # NOTE(review): the classes used below - primary-button, secondary-button,
    # input-area, content-card and app-subtitle - have no rules in this stylesheet.
    custom_css = """
:root {
--primary-color: #FF000C;
--secondary-color: #212832;
--background-color: var(--body-background-fill);
--card-color: var(--block-background-fill);
--text-color: var(--body-text-color);
--border-radius: 12px;
--shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
.app-header {
background-color: var(--secondary-color);
padding: 20px;
border-radius: var(--border-radius);
margin-bottom: 20px;
box-shadow: var(--shadow);
display: flex;
align-items: center;
justify-content: space-between;
}
.app-header img {
max-width: 120px;
}
.app-title {
color: white;
margin: 0;
font-size: 24px;
font-weight: 600;
}
.status-card, .catalog-card, .chat-card {
background-color: var(--card-color);
border-radius: var(--border-radius);
padding: 15px;
margin-bottom: 20px;
box-shadow: var(--shadow);
}
.chat-card {
height: 100%;
}
.message {
padding: 10px 15px;
border-radius: 8px;
margin-bottom: 10px;
max-width: 85%;
}
.user-message {
background-color: var(--primary-color);
color: white;
margin-left: auto;
}
.bot-message {
background-color: #f0f0f0;
color: var(--text-color);
margin-right: auto;
}
.footer {
text-align: center;
margin-top: 20px;
font-size: 12px;
color: var(--text-color);
}
.action-button {
background-color: var(--primary-color);
color: white;
border: none;
border-radius: var(--border-radius);
padding: 8px 16px;
cursor: pointer;
transition: all 0.3s ease;
}
.action-button:hover {
opacity: 0.9;
}
"""
    # Create the Gradio interface. Layout order inside the context managers
    # determines on-screen order.
    with gr.Blocks(css=custom_css) as app:
        # Setup status variable.
        # NOTE(review): setup_status is never read or updated by any handler here.
        setup_status = gr.State("System is setting up. Please wait...")
        # Replaced by setup_and_update's return value via app.load below.
        status_display = gr.Markdown("System is setting up. Please wait...")
        with gr.Column(scale=1):
            # Modern header
            with gr.Row(elem_classes="app-header"):
                with gr.Column(scale=1):
                    gr.Image(value="https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/ABB_logo.svg/2560px-ABB_logo.svg.png",
                             width=120,
                             height=120,
                             interactive=False,
                             label="ABB Logo")
                with gr.Column(scale=3):
                    gr.HTML('<h1 class="app-title">Ginnie</h1>')
                    gr.HTML('<p class="app-subtitle">Your AI assistant for ABB product information</p>')
            # Chat interface
            with gr.Row():
                with gr.Column(scale=3):
                    # Chat interface with custom styling.
                    # NOTE(review): each gr.HTML is presumably rendered as its own
                    # component, so this opening <div> and the closing one below
                    # likely do not wrap the components between them - confirm.
                    gr.HTML('<div class="content-card">')
                    # History is a list of (user, bot) pairs, as built by process_message.
                    chatbot = gr.Chatbot(
                        value=[],
                        elem_id="chatbot",
                        height=500,
                        show_copy_button=True,
                        avatar_images=["https://ui-avatars.com/api/?name=You&background=0D8ABC&color=fff",
                                       "https://ui-avatars.com/api/?name=Ginnie&background=FF000C&color=fff"]
                    )
                    # Message input with better styling
                    with gr.Row(elem_classes="input-area"):
                        msg = gr.Textbox(
                            placeholder="Ask about ABB products...",
                            label="",
                            lines=2,
                            max_lines=5,
                            show_label=False
                        )
                        send_btn = gr.Button("Send", elem_classes="primary-button")
                    with gr.Row():
                        clear_btn = gr.Button("Clear Chat", elem_classes="secondary-button")
                    gr.HTML('</div>')
                with gr.Column(scale=1):
                    # Quick tips card (same open/close-tag caveat as above).
                    gr.HTML('<div class="status-card">')
                    # NOTE(review): these tips expose the S3 bucket name to every visitor.
                    gr.HTML('''
<h3>Quick Tips</h3>
<ul>
<li>Ask about specific ABB products</li>
<li>Inquire about technical specifications</li>
<li>Ask about installation and maintenance</li>
<li>Get help with troubleshooting</li>
<li>S3 Bucket Name=agent-product-discovery</li>
<li>ABB Ability™ System 800xA® 6.2.pdf, Enclosed Softstarters.pdf, Ex-Solutions.pdf, Low_power_UPS_catalogue_EN.pdf</li>
</ul>
''')
                    gr.HTML('</div>')
                    # Admin settings: document ingestion controls.
                    with gr.Accordion("Admin Settings", open=False):
                        with gr.Tab("Process PDFs"):
                            s3_bucket = gr.Textbox(label="S3 Bucket Name")
                            s3_prefix = gr.Textbox(label="S3 Prefix (folder)", value="catalogs/")
                            process_btn = gr.Button("Process PDFs from S3", elem_classes="action-button")
                        # Add direct PDF URL input
                        with gr.Tab("Direct PDF URLs"):
                            pdf_url = gr.Textbox(label="PDF URL", placeholder="https://example.com/sample.pdf")
                            # Picking an entry copies it into pdf_url (see pdf_dropdown.change below).
                            pdf_dropdown = gr.Dropdown(
                                label="ABB Catalog PDFs",
                                choices=[
                                    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/ABB+Ability%E2%84%A2+System+800xA%C2%AE+6.2.pdf",
                                    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/Enclosed+Softstarters.pdf",
                                    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/Ex-Solutions.pdf",
                                    "https://agent-product-discovery.s3.ap-south-1.amazonaws.com/ABB-catalog/Low_power_UPS_catalogue_EN.pdf"
                                ],
                                interactive=True
                            )
                            process_url_btn = gr.Button("Process PDF from URL", elem_classes="action-button")
                        # Shared output for both the S3 and URL processing buttons.
                        result_text = gr.Textbox(label="Processing Result")
        # Set up event handlers.
        # NOTE(review): process_message outputs only [chatbot], so the message box
        # is not cleared after sending, and render_images() is not wired to any
        # component here.
        send_btn.click(
            process_message,
            [msg, chatbot],
            [chatbot],
            api_name="send_message"
        )
        msg.submit(
            process_message,
            [msg, chatbot],
            [chatbot],
            api_name="send_message_enter"
        )
        clear_btn.click(
            reset_chat,
            [chatbot],
            [chatbot],
            api_name="clear_chat"
        )
        process_btn.click(
            process_pdfs_from_s3,
            [s3_bucket, s3_prefix],
            [result_text],
            api_name="process_pdfs"
        )
        # Add this event handler
        process_url_btn.click(
            process_pdf_from_url,
            [pdf_url],
            [result_text],
            api_name="process_pdf_url"
        )
        # Add this dropdown change event: identity function copies the selection into pdf_url.
        pdf_dropdown.change(
            lambda x: x,
            [pdf_dropdown],
            [pdf_url],
            api_name="update_pdf_url"
        )
        # Add the system setup to run when the app loads
        app.load(setup_and_update, None, status_display)
    return app
# Entry point: only runs when this file is executed directly, not on import.
if __name__ == "__main__":
    # Build the Blocks UI, then launch it with share=True.
    demo = create_gradio_app()
    demo.launch(share=True)