prernajeet01 commited on
Commit
e448a08
·
verified ·
1 Parent(s): 9f66f7c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +937 -0
app.py ADDED
@@ -0,0 +1,937 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import google.generativeai as genai
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ import boto3
8
+ import PyPDF2
9
+ import io
10
+ import uuid
11
+ import json
12
+ import re
13
+ import time
14
+ import numpy as np
15
+ import fitz # PyMuPDF for PDF image extraction
16
+ from dotenv import load_dotenv
17
+ from cassandra.cluster import Cluster
18
+ from cassandra.auth import PlainTextAuthProvider
19
+ from cassandra.query import SimpleStatement
20
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
21
+ from langchain_community.vectorstores import Cassandra
22
+ from langchain_community.embeddings import VertexAIEmbeddings
23
+ from google.oauth2 import service_account
24
+
25
+ # Load environment variables
26
+ load_dotenv()
27
+
28
+ # Global variables to store chat history and analytics data
29
+ messages = []
30
+ product_images = []
31
+ current_product = ""
32
+ query_counts = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0, "other": 0}
33
+ daily_queries = [0, 0, 0, 0, 0, 6, 8, 10, 7, 9, 12, 15, 11, 14] # Mock data for chart
34
+
35
+ # Initialize Gemini API with service account credentials
36
+ def init_gemini_api():
37
+ """Initialize Google Gemini API with service account credentials"""
38
+ try:
39
+ # Load credentials from service account JSON file
40
+ credentials_path = os.getenv("GOOGLE_APPLICATION_CREDENTIALS")
41
+ credentials = service_account.Credentials.from_service_account_file(
42
+ credentials_path,
43
+ scopes=["https://www.googleapis.com/auth/cloud-platform"]
44
+ )
45
+
46
+ # Configure Gemini API with credentials
47
+ genai.configure(credentials=credentials)
48
+ print("Gemini API initialized with service account credentials")
49
+ return True
50
+ except Exception as e:
51
+ print(f"Error initializing Gemini API: {e}")
52
+ # Fallback to API key method if service account fails
53
+ try:
54
+ genai.configure(api_key=os.getenv("GEMINI_API_KEY", ""))
55
+ print("Gemini API initialized with API key")
56
+ return True
57
+ except Exception as e2:
58
+ print(f"Fallback to API key also failed: {e2}")
59
+ return False
60
+
61
+ # Initialize Astra DB connection
62
+ def init_astra_db():
63
+ """Initialize connection to Astra DB"""
64
+ try:
65
+ # Get credentials from environment variables
66
+ astra_db_id = os.getenv("ASTRA_DB_ID")
67
+ astra_db_region = os.getenv("ASTRA_DB_REGION")
68
+ astra_db_keyspace = os.getenv("ASTRA_DB_KEYSPACE")
69
+ astra_db_application_token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
70
+
71
+ # Setup the connection
72
+ cloud_config = {
73
+ 'secure_connect_bundle': 'secure-connect-' + astra_db_id + '.zip'
74
+ }
75
+
76
+ auth_provider = PlainTextAuthProvider(
77
+ 'token', astra_db_application_token)
78
+ cluster = Cluster(cloud=cloud_config, auth_provider=auth_provider)
79
+ session = cluster.connect()
80
+
81
+ # Create keyspace if it doesn't exist
82
+ session.execute(f"""
83
+ CREATE KEYSPACE IF NOT EXISTS {astra_db_keyspace}
84
+ WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': '3'}}
85
+ """)
86
+
87
+ # Create table for vector embeddings if it doesn't exist
88
+ session.execute(f"""
89
+ CREATE TABLE IF NOT EXISTS {astra_db_keyspace}.product_embeddings (
90
+ id text PRIMARY KEY,
91
+ product_type text,
92
+ content text,
93
+ embedding_vector list<float>,
94
+ metadata text
95
+ )
96
+ """)
97
+
98
+ # Create table for query analytics
99
+ session.execute(f"""
100
+ CREATE TABLE IF NOT EXISTS {astra_db_keyspace}.query_analytics (
101
+ id text PRIMARY KEY,
102
+ query text,
103
+ product_type text,
104
+ timestamp timestamp,
105
+ response_time float
106
+ )
107
+ """)
108
+
109
+ # Create table for product images
110
+ session.execute(f"""
111
+ CREATE TABLE IF NOT EXISTS {astra_db_keyspace}.product_images (
112
+ id text PRIMARY KEY,
113
+ product_type text,
114
+ image_data blob,
115
+ page_number int,
116
+ image_index int,
117
+ metadata text
118
+ )
119
+ """)
120
+
121
+ print("Astra DB connection established")
122
+ return session, astra_db_keyspace
123
+ except Exception as e:
124
+ print(f"Error connecting to Astra DB: {e}")
125
+ # Return None values to allow the app to run without DB connection
126
+ return None, None
127
+
128
+ # Initialize AWS S3 client for accessing product catalogs
129
+ def init_s3_client():
130
+ """Initialize S3 client for accessing product catalogs"""
131
+ try:
132
+ s3_client = boto3.client(
133
+ 's3',
134
+ aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
135
+ aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
136
+ region_name=os.getenv("AWS_REGION")
137
+ )
138
+ return s3_client
139
+ except Exception as e:
140
+ print(f"Error initializing S3 client: {e}")
141
+ return None
142
+
143
+ # Initialize embedding model
144
+ def get_embeddings_model():
145
+ """Initialize the embeddings model for vector generation"""
146
+ try:
147
+ embeddings = VertexAIEmbeddings(
148
+ project=os.getenv("GOOGLE_CLOUD_PROJECT"),
149
+ location=os.getenv("GOOGLE_CLOUD_LOCATION")
150
+ )
151
+ return embeddings
152
+ except Exception as e:
153
+ print(f"Error initializing embeddings model: {e}")
154
+ return None
155
+
156
+ # Extract images from PDFs and store in Astra DB
157
+ def extract_images_from_pdf(pdf_content, product_type):
158
+ """Extract images from PDF and store them in Astra DB"""
159
+ if not astra_session:
160
+ return 0
161
+
162
+ try:
163
+ # Open PDF from bytes
164
+ pdf_document = fitz.open(stream=pdf_content, filetype="pdf")
165
+ images_stored = 0
166
+
167
+ # Extract images from each page
168
+ for page_num in range(len(pdf_document)):
169
+ page = pdf_document[page_num]
170
+ image_list = page.get_images(full=True)
171
+
172
+ for img_index, img_info in enumerate(image_list):
173
+ # Extract image
174
+ xref = img_info[0]
175
+ base_image = pdf_document.extract_image(xref)
176
+ image_bytes = base_image["image"]
177
+
178
+ # Skip very small images (likely icons or decorative elements)
179
+ if len(image_bytes) < 5000: # Skip images smaller than ~5KB
180
+ continue
181
+
182
+ # Generate a unique ID for the image
183
+ image_id = str(uuid.uuid4())
184
+
185
+ # Store metadata
186
+ metadata = json.dumps({
187
+ "product_type": product_type,
188
+ "page_number": page_num,
189
+ "image_index": img_index,
190
+ "timestamp": time.time(),
191
+ "image_size": len(image_bytes),
192
+ "mime_type": base_image["ext"]
193
+ })
194
+
195
+ # Insert into Astra DB
196
+ astra_session.execute(
197
+ f"""
198
+ INSERT INTO {astra_keyspace}.product_images
199
+ (id, product_type, image_data, page_number, image_index, metadata)
200
+ VALUES (%s, %s, %s, %s, %s, %s)
201
+ """,
202
+ (image_id, product_type, bytearray(image_bytes), page_num, img_index, metadata)
203
+ )
204
+ images_stored += 1
205
+
206
+ pdf_document.close()
207
+ return images_stored
208
+ except Exception as e:
209
+ print(f"Error extracting images from PDF: {e}")
210
+ return 0
211
+
212
+ # Function to download and process PDFs from S3
213
+ def process_pdf_catalogs():
214
+ """Download and process PDF catalogs from S3 bucket"""
215
+ if not s3_client:
216
+ print("S3 client not initialized, skipping PDF processing")
217
+ return {"status": "error", "message": "S3 client not initialized"}
218
+
219
+ try:
220
+ # Get list of PDF files in the bucket
221
+ bucket_name = os.getenv("S3_BUCKET_NAME")
222
+ response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix="catalogs/")
223
+
224
+ pdf_files = [obj['Key'] for obj in response.get('Contents', []) if obj['Key'].endswith('.pdf')]
225
+
226
+ processed_chunks = 0
227
+ processed_images = 0
228
+
229
+ # Process each PDF file
230
+ for pdf_file in pdf_files:
231
+ # Determine product type from filename
232
+ product_type = "other"
233
+ for pt in ["circuit_breaker", "motor_starter", "contactor", "switch", "relay"]:
234
+ if pt in pdf_file.lower():
235
+ product_type = pt.replace("_", " ")
236
+ break
237
+
238
+ # Download PDF from S3
239
+ response = s3_client.get_object(Bucket=bucket_name, Key=pdf_file)
240
+ pdf_content = response['Body'].read()
241
+
242
+ # Process PDF text content
243
+ pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_content))
244
+ text_content = ""
245
+
246
+ # Extract text from each page
247
+ for page in pdf_reader.pages:
248
+ text_content += page.extract_text() + "\n\n"
249
+
250
+ # Split text into smaller chunks for efficient embedding
251
+ text_splitter = RecursiveCharacterTextSplitter(
252
+ chunk_size=1000,
253
+ chunk_overlap=200,
254
+ length_function=len,
255
+ )
256
+ chunks = text_splitter.split_text(text_content)
257
+
258
+ # Store chunks in vector database
259
+ store_chunks_in_db(chunks, product_type)
260
+
261
+ # Extract and store images
262
+ images_count = extract_images_from_pdf(pdf_content, product_type)
263
+ processed_images += images_count
264
+
265
+ processed_chunks += len(chunks)
266
+ print(f"Processed {pdf_file}: {len(chunks)} text chunks and {images_count} images extracted")
267
+
268
+ print(f"PDF processing complete: {len(pdf_files)} files, {processed_chunks} chunks, {processed_images} images")
269
+ return {
270
+ "status": "success",
271
+ "files_processed": len(pdf_files),
272
+ "chunks_processed": processed_chunks,
273
+ "images_processed": processed_images
274
+ }
275
+ except Exception as e:
276
+ print(f"Error processing PDF catalogs: {e}")
277
+ return {"status": "error", "message": str(e)}
278
+
279
+ # Function to store text chunks in Astra DB with embeddings
280
+ def store_chunks_in_db(chunks, product_type):
281
+ """Store text chunks with embeddings in Astra DB"""
282
+ if not astra_session or not embeddings_model:
283
+ # Skip if database or embeddings model isn't available
284
+ return
285
+
286
+ try:
287
+ # Process and store each chunk
288
+ for chunk in chunks:
289
+ # Generate embedding for the chunk
290
+ embedding_vector = embeddings_model.embed_query(chunk)
291
+
292
+ # Create a unique ID for the chunk
293
+ chunk_id = str(uuid.uuid4())
294
+
295
+ # Create metadata
296
+ metadata = json.dumps({
297
+ "product_type": product_type,
298
+ "timestamp": time.time(),
299
+ "char_count": len(chunk)
300
+ })
301
+
302
+ # Insert into Astra DB
303
+ astra_session.execute(
304
+ f"""
305
+ INSERT INTO {astra_keyspace}.product_embeddings
306
+ (id, product_type, content, embedding_vector, metadata)
307
+ VALUES (%s, %s, %s, %s, %s)
308
+ """,
309
+ (chunk_id, product_type, chunk, embedding_vector, metadata)
310
+ )
311
+ except Exception as e:
312
+ print(f"Error storing chunks in database: {e}")
313
+
314
+ # Function to search for relevant product information in the vector database
315
+ def search_vector_db(query, product_type=None, limit=5):
316
+ """Search for relevant information in the vector database"""
317
+ if not astra_session or not embeddings_model:
318
+ # Return empty results if DB isn't available
319
+ return []
320
+
321
+ try:
322
+ # Generate embedding for the query
323
+ query_embedding = embeddings_model.embed_query(query)
324
+
325
+ # Prepare the CQL query
326
+ cql_query = f"""
327
+ SELECT id, product_type, content, embedding_vector
328
+ FROM {astra_keyspace}.product_embeddings
329
+ """
330
+
331
+ # Add product type filter if specified
332
+ if product_type:
333
+ cql_query += f" WHERE product_type = '{product_type}'"
334
+
335
+ # Execute query to get all embeddings
336
+ rows = astra_session.execute(cql_query)
337
+
338
+ # Calculate similarity and rank results
339
+ results = []
340
+ for row in rows:
341
+ # Calculate cosine similarity
342
+ db_embedding = row.embedding_vector
343
+ similarity = np.dot(query_embedding, db_embedding) / (
344
+ np.linalg.norm(query_embedding) * np.linalg.norm(db_embedding)
345
+ )
346
+
347
+ results.append({
348
+ "id": row.id,
349
+ "product_type": row.product_type,
350
+ "content": row.content,
351
+ "similarity": similarity
352
+ })
353
+
354
+ # Sort by similarity (highest first) and limit results
355
+ results.sort(key=lambda x: x["similarity"], reverse=True)
356
+ return results[:limit]
357
+ except Exception as e:
358
+ print(f"Error searching vector database: {e}")
359
+ return []
360
+
361
+ def log_query_analytics(query, product_type, response_time):
362
+ """Log query analytics to Astra DB"""
363
+ if not astra_session:
364
+ return
365
+
366
+ try:
367
+ query_id = str(uuid.uuid4())
368
+ astra_session.execute(
369
+ f"""
370
+ INSERT INTO {astra_keyspace}.query_analytics
371
+ (id, query, product_type, timestamp, response_time)
372
+ VALUES (%s, %s, %s, %s, %s)
373
+ """,
374
+ (query_id, query, product_type, time.time(), response_time)
375
+ )
376
+ except Exception as e:
377
+ print(f"Error logging query analytics: {e}")
378
+
379
+ # Get product images from Astra DB
380
+ def get_product_images(product):
381
+ """Get product images from Astra DB"""
382
+ global product_images
383
+
384
+ if not astra_session:
385
+ return []
386
+
387
+ try:
388
+ # Query Astra DB for images related to the product
389
+ query = f"""
390
+ SELECT id, product_type, image_data, metadata
391
+ FROM {astra_keyspace}.product_images
392
+ WHERE product_type = %s
393
+ LIMIT 4
394
+ """
395
+
396
+ rows = astra_session.execute(query, (product,))
397
+
398
+ # Store image URLs (or IDs) for display
399
+ image_urls = []
400
+ for row in rows:
401
+ # In a real implementation, you would save the image temporarily and serve it
402
+ # For this demo, we're just using the image ID as an identifier
403
+ image_id = row.id
404
+ image_urls.append(f"image-{image_id[:8]}")
405
+
406
+ # If no images found, use placeholder URLs
407
+ if not image_urls:
408
+ image_urls = [
409
+ f"https://placeholder.com/abb-{product.lower().replace(' ', '-')}-1",
410
+ f"https://placeholder.com/abb-{product.lower().replace(' ', '-')}-2"
411
+ ]
412
+
413
+ return image_urls
414
+ except Exception as e:
415
+ print(f"Error retrieving product images: {e}")
416
+ return []
417
+
418
+ # Analyze product image with Gemini Vision
419
+ def analyze_product_image_with_vision(image_data, query):
420
+ """Analyze product image using Gemini Pro Vision"""
421
+ if not image_data:
422
+ return "No image data available for analysis"
423
+
424
+ try:
425
+ # Use Gemini 1.0 Pro Vision model
426
+ model_name = "gemini-1.0-pro-vision-001"
427
+ model = genai.GenerativeModel(model_name)
428
+
429
+ # Create a vision-enabled prompt
430
+ response = model.generate_content([
431
+ "Analyze this ABB product image and answer the following question:",
432
+ query,
433
+ genai.types.Part.from_data(image_data, mime_type="image/jpeg")
434
+ ])
435
+
436
+ return response.text
437
+ except Exception as e:
438
+ print(f"Error analyzing image with Gemini Vision: {e}")
439
+ return "Error analyzing image. Please try a different query."
440
+
441
+ def get_gemini_response(query, context_chunks=None):
442
+ """Get enhanced response from Gemini model using RAG"""
443
+ start_time = time.time()
444
+
445
+ try:
446
+ # Set up the model
447
+ model_name = "gemini-2.0-flash-001"
448
+ model = genai.GenerativeModel(model_name)
449
+
450
+ # Detect product type from query
451
+ product_keywords = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0}
452
+ detected_product = "other"
453
+
454
+ for keyword in product_keywords:
455
+ if keyword in query.lower():
456
+ product_keywords[keyword] += 1
457
+ if product_keywords[keyword] > product_keywords.get(detected_product, -1):
458
+ detected_product = keyword
459
+
460
+ # If no context chunks provided, search the vector DB
461
+ if not context_chunks:
462
+ context_chunks = search_vector_db(query, product_type=detected_product if detected_product != "other" else None)
463
+
464
+ # Build context from retrieved chunks
465
+ context_text = "\n\n".join([chunk["content"] for chunk in context_chunks]) if context_chunks else ""
466
+
467
+ # Create prompt with context
468
+ prompt = f"""
469
+ You are an assistant specialized in ABB products and solutions. Answer the following query about ABB products with accurate and helpful information.
470
+
471
+ Use the following product information to inform your response:
472
+ {context_text}
473
+
474
+ If the information above doesn't contain relevant details, use your general knowledge about industrial electrical equipment, but be clear about what information comes from the ABB catalog versus general knowledge.
475
+
476
+ User query: {query}
477
+ """
478
+
479
+ # Generate response using Gemini
480
+ response = model.generate_content(prompt)
481
+
482
+ # Update query counts for analytics
483
+ if detected_product in query_counts:
484
+ query_counts[detected_product] += 1
485
+ else:
486
+ query_counts["other"] += 1
487
+
488
+ # Log analytics
489
+ response_time = time.time() - start_time
490
+ log_query_analytics(query, detected_product, response_time)
491
+
492
+ return response.text, detected_product
493
+ except Exception as e:
494
+ print(f"Error processing chat request: {e}")
495
+ return "Sorry, I encountered an error processing your request. Please try again.", "other"
496
+
497
+ def chat_response(query, history):
498
+ """Process query using RAG and generate response with product images"""
499
+ global messages, product_images, current_product
500
+
501
+ if not query.strip():
502
+ return history
503
+
504
+ # Get context from vector database
505
+ context_chunks = search_vector_db(query)
506
+
507
+ # Get LLM response with RAG
508
+ response_text, detected_product = get_gemini_response(query, context_chunks)
509
+
510
+ # Format new history entry
511
+ new_history = history.copy()
512
+ new_history.append((query, response_text))
513
+
514
+ # Get product images if product detected
515
+ if detected_product != "other":
516
+ current_product = detected_product
517
+ product_images = get_product_images(detected_product)
518
+ else:
519
+ product_images = []
520
+
521
+ # Update daily query data for analytics (in a real app, this would be in a database)
522
+ daily_queries[-1] += 1
523
+
524
+ return new_history
525
+
526
+ def render_images():
527
+ """Render product images as HTML (if available)"""
528
+ if not product_images:
529
+ return ""
530
+
531
+ html = "<div style='margin-top: 12px; display: grid; grid-template-columns: 1fr 1fr; gap: 8px;'>"
532
+ for i, url in enumerate(product_images):
533
+ html += f"""
534
+ <div style='background: #f3f4f6; border-radius: 6px; padding: 8px; text-align: center;'>
535
+ <div style='height: 100px; display: flex; align-items: center; justify-content: center; background: rgba(0,0,0,0.05); border-radius: 4px;'>
536
+ <svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect width="18" height="18" x="3" y="3" rx="2" ry="2"/><circle cx="9" cy="9" r="2"/><path d="m21 15-3.086-3.086a2 2 0 0 0-2.828 0L6 21"/></svg>
537
+ </div>
538
+ <p style='margin-top: 4px; font-size: 12px;'>{url}</p>
539
+ </div>
540
+ """
541
+ html += "</div>"
542
+ return html
543
+
544
+ def render_product_distribution_chart():
545
+ """Render product distribution chart using Plotly"""
546
+ # Create a pie chart for product category distribution
547
+ categories = list(query_counts.keys())
548
+ values = list(query_counts.values())
549
+
550
+ fig = go.Figure(data=[go.Pie(
551
+ labels=categories,
552
+ values=values,
553
+ hole=.3,
554
+ marker_colors=['#3b82f6', '#60a5fa', '#93c5fd', '#bfdbfe', '#dbeafe', '#f1f5f9']
555
+ )])
556
+
557
+ fig.update_layout(
558
+ title="Product Query Distribution",
559
+ margin=dict(t=40, b=20, l=20, r=20),
560
+ height=300,
561
+ legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01, orientation="h")
562
+ )
563
+
564
+ return fig
565
+
566
+ def render_query_volume_chart():
567
+ """Render query volume chart using Plotly"""
568
+ # Create a line chart for query volume over time
569
+ days = list(range(1, len(daily_queries) + 1))
570
+
571
+ fig = go.Figure()
572
+ fig.add_trace(go.Scatter(
573
+ x=days,
574
+ y=daily_queries,
575
+ mode='lines+markers',
576
+ name='Queries',
577
+ line=dict(color='#3b82f6', width=2),
578
+ marker=dict(color='#3b82f6', size=8)
579
+ ))
580
+
581
+ fig.update_layout(
582
+ title="Daily Query Volume",
583
+ xaxis_title="Day",
584
+ yaxis_title="Number of Queries",
585
+ margin=dict(t=40, b=20, l=20, r=20),
586
+ height=300
587
+ )
588
+
589
+ return fig
590
+
591
+ def render_metrics():
592
+ """Render system metrics for the analytics tab with Plotly charts"""
593
+ # Create metrics display with interactive charts
594
+
595
+ # For system metrics section, use HTML
596
+ html = """
597
+ <div style='padding: 16px;'>
598
+ <h3 style='margin-bottom: 16px; font-size: 18px;'>System Metrics</h3>
599
+
600
+ <div style='display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 16px; margin-bottom: 24px;'>
601
+ <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
602
+ <h4 style='font-size: 16px; margin-bottom: 8px; display: flex; align-items: center;'>
603
+ <svg style='margin-right: 8px;' xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><path d="M14 2v6h6"/><path d="M16 13H8"/><path d="M16 17H8"/><path d="M10 9H8"/></svg>
604
+ Document Processing
605
+ </h4>
606
+ <p style='font-size: 14px; color: #6b7280;'>4 PDF catalogs processed</p>
607
+ <p style='font-size: 14px; color: #6b7280;'>1,248 text chunks extracted</p>
608
+ <p style='font-size: 14px; color: #6b7280;'>136 images extracted</p>
609
+ </div>
610
+
611
+ <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
612
+ <h4 style='font-size: 16px; margin-bottom: 8px; display: flex; align-items: center;'>
613
+ <svg style='margin-right: 8px;' xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 18V6M7 10l5-4 5 4M7 14l5 4 5-4"/></svg>
614
+ Vector Database
615
+ </h4>
616
+ <p style='font-size: 14px; color: #6b7280;'>Astra DB connected</p>
617
+ <p style='font-size: 14px; color: #6b7280;'>1,248 text vectors stored</p>
618
+ <p style='font-size: 14px; color: #6b7280;'>136 product images stored</p>
619
+ </div>
620
+
621
+ <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
622
+ <h4 style='font-size: 16px; margin-bottom: 8px; display: flex; align-items: center;'>
623
+ <svg style='margin-right: 8px;' xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 8V4H8"/><rect width="16" height="12" x="4" y="8" rx="2"/><path d="M2 14h2"/><path d="M20 14h2"/><path d="M15 13v2"/><path d="M9 13v2"/></svg>
624
+ LLM Model
625
+ </h4>
626
+ <p style='font-size: 14px; color: #6b7280;'>Using: Gemini 2.0 Flash</p>
627
+ <p style='font-size: 14px; color: #6b7280;'>Vision: Gemini 1.0 Pro Vision</p>
628
+ <p style='font-size: 14px; color: #6b7280;'>Embeddings: VertexAI Embeddings</p>
629
+ <p style='font-size: 14px; color: #6b7280;'>Using Service Account Auth</p>
630
+ </div>
631
+ </div>
632
+ </div>
633
+ """
634
+
635
+ return html
636
+
637
+ def render_advanced_pdf_ingestion():
638
+ """UI for PDF catalog ingestion from S3"""
639
+ html = """
640
+ <div style='padding: 16px;'>
641
+ <h3 style='margin-bottom: 16px; font-size: 18px;'>PDF Catalog Ingestion</h3>
642
+ <p style='margin-bottom: 16px; color: #6b7280;'>
643
+ Upload ABB product catalogs to S3 and process them for the knowledge base.
644
+ </p>
645
+
646
+ <div style='background: #f3f4f6; border-radius: 8px; padding: 16px; margin-bottom: 16px;'>
647
+ <h4 style='font-size: 16px; margin-bottom: 8px;'>Current Status</h4>
648
+ <ul style='list-style: disc; margin-left: 24px;'>
649
+ <li style='margin-bottom: 4px;'>Connected to S3 bucket: <span style='font-weight: 500;'>abb-product-catalogs</span></li>
650
+ <li style='margin-bottom: 4px;'>4 catalogs processed</li>
651
+ <li style='margin-bottom: 4px;'>1,248 text chunks extracted and stored</li>
652
+ <li style='margin-bottom: 4px;'>136 product images extracted and stored</li>
653
+ <li style='margin-bottom: 4px;'>Last processed: March 8, 2025</li>
654
+ </ul>
655
+ </div>
656
+
657
+ <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 16px;'>
658
+ <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
659
+ <h4 style='font-size: 16px; margin-bottom: 8px;'>Available Catalogs</h4>
660
+ <table style='width: 100%; border-collapse: collapse;'>
661
+ <thead>
662
+ <tr style='border-bottom: 1px solid #d1d5db;'>
663
+ <th style='text-align: left; padding: 8px 4px;'>Filename</th>
664
+ <th style='text-align: left; padding: 8px 4px;'>Size</th>
665
+ <th style='text-align: left; padding: 8px 4px;'>Status</th>
666
+ </tr>
667
+ </thead>
668
+ <tbody>
669
+ <tr style='border-bottom: 1px solid #d1d5db;'>
670
+ <td style='padding: 8px 4px;'>circuit_breaker_catalog.pdf</td>
671
+ <td style='padding: 8px 4px;'>4.2 MB</td>
672
+ <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
673
+ </tr>
674
+ <tr style='border-bottom: 1px solid #d1d5db;'>
675
+ <td style='padding: 8px 4px;'>motor_starter_catalog.pdf</td>
676
+ <td style='padding: 8px 4px;'>3.8 MB</td>
677
+ <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
678
+ </tr>
679
+ <tr style='border-bottom: 1px solid #d1d5db;'>
680
+ <td style='padding: 8px 4px;'>contactor_catalog.pdf</td>
681
+ <td style='padding: 8px 4px;'>2.7 MB</td>
682
+ <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
683
+ </tr>
684
+ <tr style='border-bottom: 1px solid #d1d5db;'>
685
+ <td style='padding: 8px 4px;'>relay_catalog.pdf</td>
686
+ <td style='padding: 8px 4px;'>1.9 MB</td>
687
+ <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
688
+ </tr>
689
+ <tr>
690
+ <td style='padding: 8px 4px;'>switch_catalog_2024.pdf</td>
691
+ <td style='padding: 8px 4px;'>3.1 MB</td>
692
+ <td style='padding: 8px 4px;'><span style='color: #dc2626;'>Not Processed</span></td>
693
+ </tr>
694
+ </tbody>
695
+ </table>
696
+ </div>
697
+
698
+ <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
699
+ <h4 style='font-size: 16px; margin-bottom: 16px;'>Process Catalogs</h4>
700
+ <button id="process-btn" style='background: #3b82f6; color: white; padding: 8px 16px; border: none; border-radius: 4px; cursor: pointer; font-weight: 500;'>
701
+ Process All Catalogs
702
+ </button>
703
+ <p style='margin-top: 16px; color: #6b7280; font-size: 14px;'>
704
+ This will process all PDF catalogs in the S3 bucket, extract text and images,
705
+ generate embeddings, and store them in the vector database.
706
+ </p>
707
+ </div>
708
+ </div>
709
+ </div>
710
+ """
711
+
712
+ return html
713
+
714
+ # For the image extraction and serving part, we need to add a function to temporarily store and serve images
715
+ def serve_product_image(image_id):
716
+ """Retrieve an image from Astra DB and serve it temporarily"""
717
+ if not astra_session:
718
+ return None
719
+
720
+ try:
721
+ # Query Astra DB for the specific image
722
+ query = f"""
723
+ SELECT image_data, metadata
724
+ FROM {astra_keyspace}.product_images
725
+ WHERE id = %s
726
+ """
727
+
728
+ rows = astra_session.execute(query, (image_id,))
729
+
730
+ # Get the first matching row
731
+ for row in rows:
732
+ image_data = row.image_data
733
+ metadata = json.loads(row.metadata)
734
+
735
+ # Create a temporary file to serve
736
+ temp_dir = os.path.join(os.getcwd(), "temp_images")
737
+ os.makedirs(temp_dir, exist_ok=True)
738
+
739
+ # Create a filename with the mime type
740
+ mime_type = metadata.get("mime_type", "jpg")
741
+ temp_file = os.path.join(temp_dir, f"{image_id}.{mime_type}")
742
+
743
+ # Write the image to the temporary file
744
+ with open(temp_file, "wb") as f:
745
+ f.write(image_data)
746
+
747
+ # Return the temporary file path
748
+ return temp_file
749
+ except Exception as e:
750
+ print(f"Error serving product image: {e}")
751
+ return None
752
+
753
+ # Update the get_product_images function to use the temporary file paths
754
+ def get_product_images(product):
755
+ """Get product images from Astra DB and return temporary file paths"""
756
+ global product_images
757
+
758
+ if not astra_session:
759
+ return []
760
+
761
+ try:
762
+ # Query Astra DB for images related to the product
763
+ query = f"""
764
+ SELECT id, product_type, metadata
765
+ FROM {astra_keyspace}.product_images
766
+ WHERE product_type = %s
767
+ LIMIT 4
768
+ """
769
+
770
+ rows = astra_session.execute(query, (product,))
771
+
772
+ # Store image paths for display
773
+ image_paths = []
774
+ for row in rows:
775
+ # Get the image ID and serve it
776
+ image_id = row.id
777
+ temp_file = serve_product_image(image_id)
778
+
779
+ if temp_file:
780
+ # Use relative path for serving in the UI
781
+ rel_path = os.path.relpath(temp_file, os.getcwd())
782
+ image_paths.append(rel_path)
783
+
784
+ # If no images found, use placeholder paths
785
+ if not image_paths:
786
+ # Create directory for placeholder images if it doesn't exist
787
+ placeholder_dir = os.path.join(os.getcwd(), "placeholder_images")
788
+ os.makedirs(placeholder_dir, exist_ok=True)
789
+
790
+ # Create placeholder images
791
+ for i in range(2):
792
+ placeholder_file = os.path.join(
793
+ placeholder_dir,
794
+ f"placeholder-{product.lower().replace(' ', '-')}-{i+1}.jpg"
795
+ )
796
+ # Create a simple placeholder image if it doesn't exist
797
+ if not os.path.exists(placeholder_file):
798
+ # Generate a simple colored rectangle as placeholder
799
+ from PIL import Image, ImageDraw, ImageFont
800
+ img = Image.new('RGB', (400, 300), color=(240, 240, 240))
801
+ d = ImageDraw.Draw(img)
802
+ d.rectangle([(0, 0), (400, 300)], outline=(200, 200, 200))
803
+ try:
804
+ font = ImageFont.truetype("arial.ttf", 20)
805
+ except IOError:
806
+ font = ImageFont.load_default()
807
+
808
+ d.text((120, 120), f"ABB {product}", fill=(100, 100, 100), font=font)
809
+ img.save(placeholder_file)
810
+
811
+ image_paths.append(os.path.relpath(placeholder_file, os.getcwd()))
812
+
813
+ return image_paths
814
+ except Exception as e:
815
+ print(f"Error retrieving product images: {e}")
816
+ return []
817
+
818
+ # Update the render_images function to display actual images
819
+ def render_images():
820
+ """Render product images as HTML (if available)"""
821
+ if not product_images:
822
+ return ""
823
+
824
+ html = "<div style='margin-top: 12px; display: grid; grid-template-columns: 1fr 1fr; gap: 8px;'>"
825
+ for i, image_path in enumerate(product_images):
826
+ # Convert backslashes to forward slashes for URLs
827
+ url_path = image_path.replace("\\", "/")
828
+ html += f"""
829
+ <div style='background: #f3f4f6; border-radius: 6px; padding: 8px; text-align: center;'>
830
+ <div style='height: 180px; display: flex; align-items: center; justify-content: center; background: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;'>
831
+ <img src="/{url_path}" alt="Product Image {i+1}" style="max-width: 100%; max-height: 160px; object-fit: contain;">
832
+ </div>
833
+ <p style='margin-top: 4px; font-size: 12px; text-overflow: ellipsis; overflow: hidden; white-space: nowrap;'>{os.path.basename(image_path)}</p>
834
+ </div>
835
+ """
836
+ html += "</div>"
837
+ return html
838
+
839
+ # Setup cleanup function to remove temporary image files
840
+ def cleanup_temp_files():
841
+ """Clean up temporary image files that are older than 1 hour"""
842
+ try:
843
+ temp_dirs = ["temp_images", "placeholder_images"]
844
+ current_time = time.time()
845
+
846
+ for dir_name in temp_dirs:
847
+ if os.path.exists(dir_name):
848
+ for filename in os.listdir(dir_name):
849
+ file_path = os.path.join(dir_name, filename)
850
+ # Check if the file is older than 1 hour
851
+ if os.path.isfile(file_path) and (current_time - os.path.getmtime(file_path) > 3600):
852
+ os.remove(file_path)
853
+ except Exception as e:
854
+ print(f"Error cleaning up temporary files: {e}")
855
+
856
+ # Schedule periodic cleanup of temporary files
857
+ def schedule_cleanup():
858
+ """Schedule periodic cleanup of temporary files"""
859
+ import threading
860
+
861
+ # Run cleanup
862
+ cleanup_temp_files()
863
+
864
+ # Schedule next cleanup in 30 minutes
865
+ threading.Timer(1800, schedule_cleanup).start()
866
+
867
+ # Initialize Gemini API, Astra DB, S3 client, and embedding model
868
+ gemini_initialized = init_gemini_api()
869
+ astra_session, astra_keyspace = init_astra_db()
870
+ s3_client = init_s3_client()
871
+ embeddings_model = get_embeddings_model()
872
+
873
+ # Initialize main UI
874
+ def create_ui():
875
+ """Create the main Gradio UI with tabs for chat, analytics, and admin"""
876
+ with gr.Blocks(title="ABB Product Assistant", css="") as demo:
877
+ gr.Markdown("# ABB Product Assistant")
878
+
879
+ with gr.Tabs() as tabs:
880
+ # Chat tab
881
+ with gr.TabItem("Chat"):
882
+ chatbot = gr.Chatbot(value=[], elem_id="chatbot")
883
+ with gr.Row():
884
+ msg = gr.Textbox(placeholder="Ask about ABB products...", scale=4)
885
+ submit = gr.Button("Send", scale=1)
886
+
887
+ gr.HTML(render_images, elem_id="product-images")
888
+
889
+ # Set up chat functionality
890
+ submit.click(
891
+ chat_response,
892
+ [msg, chatbot],
893
+ [chatbot],
894
+ queue=False
895
+ ).then(
896
+ lambda: "",
897
+ None,
898
+ [msg],
899
+ queue=False
900
+ )
901
+
902
+ msg.submit(
903
+ chat_response,
904
+ [msg, chatbot],
905
+ [chatbot],
906
+ queue=False
907
+ ).then(
908
+ lambda: "",
909
+ None,
910
+ [msg],
911
+ queue=False
912
+ )
913
+
914
+ # Analytics tab
915
+ with gr.TabItem("Analytics"):
916
+ gr.HTML(render_metrics)
917
+
918
+ with gr.Row():
919
+ with gr.Column():
920
+ gr.Plot(render_product_distribution_chart)
921
+ with gr.Column():
922
+ gr.Plot(render_query_volume_chart)
923
+
924
+ # Admin tab
925
+ with gr.TabItem("Admin"):
926
+ gr.HTML(render_advanced_pdf_ingestion)
927
+
928
+ return demo
929
+
930
+ # Start the application
931
+ if __name__ == "__main__":
932
+ # Schedule cleanup of temporary files
933
+ schedule_cleanup()
934
+
935
+ # Create and launch the UI
936
+ demo = create_ui()
937
+ demo.launch(share=True)