prernajeet01 commited on
Commit
87e2dbe
·
verified ·
1 Parent(s): e91d17a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +378 -453
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import os
2
  import gradio as gr
3
- import google.generativeai as genai
4
  import pandas as pd
5
  import plotly.express as px
6
  import plotly.graph_objects as go
@@ -13,14 +12,15 @@ import re
13
  import time
14
  import numpy as np
15
  import pdfplumber
 
16
  from dotenv import load_dotenv
17
  from cassandra.cluster import Cluster
18
  from cassandra.auth import PlainTextAuthProvider
19
  from cassandra.query import SimpleStatement
20
  from langchain.text_splitter import RecursiveCharacterTextSplitter
21
  from langchain_community.vectorstores import Cassandra
22
- from langchain_community.embeddings import VertexAIEmbeddings
23
- from google.oauth2 import service_account
24
 
25
  # Load environment variables
26
  load_dotenv()
@@ -32,46 +32,43 @@ current_product = ""
32
  query_counts = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0, "other": 0}
33
  daily_queries = [0, 0, 0, 0, 0, 6, 8, 10, 7, 9, 12, 15, 11, 14] # Mock data for chart
34
 
35
- # Initialize Gemini API with service account credentials
36
- def init_gemini_api():
37
- """Initialize Google Gemini API with service account credentials from Hugging Face Secrets"""
38
  try:
39
- # Retrieve service account JSON from Hugging Face Secrets
40
- service_account_json = os.getenv("SERVICE_ACCOUNT_JSON")
41
- if not service_account_json:
42
- raise ValueError("SERVICE_ACCOUNT_JSON is not set in environment variables")
43
-
44
- # Convert the single-line string back to JSON format
45
- service_account_dict = json.loads(service_account_json)
46
-
47
- # Write it to a temporary JSON file
48
- credentials_path = "service_account.json"
49
- with open(credentials_path, "w") as f:
50
- json.dump(service_account_dict, f)
51
-
52
- # Load credentials from the temporary file
53
- credentials = service_account.Credentials.from_service_account_file(
54
- credentials_path,
55
- scopes=["https://www.googleapis.com/auth/cloud-platform"]
56
- )
57
 
58
- # Configure Gemini API with credentials
59
- genai.configure(credentials=credentials)
60
- print("Gemini API initialized with service account credentials from Hugging Face Secrets")
61
  return True
62
 
63
  except Exception as e:
64
- print(f"Error initializing Gemini API with service account: {e}")
65
-
66
- # Fallback to API key method if service account fails
67
- try:
68
- genai.configure(api_key=os.getenv("GEMINI_API_KEY", ""))
69
- print("Gemini API initialized with API key")
70
- return True
71
-
72
- except Exception as e2:
73
- print(f"Fallback to API key also failed: {e2}")
 
74
  return False
 
 
 
 
 
 
 
 
 
75
 
76
  # Initialize Astra DB connection
77
  def init_astra_db():
@@ -157,11 +154,11 @@ def init_s3_client():
157
 
158
  # Initialize embedding model
159
  def get_embeddings_model():
160
- """Initialize the embeddings model for vector generation"""
161
  try:
162
- embeddings = VertexAIEmbeddings(
163
- project=os.getenv("GOOGLE_CLOUD_PROJECT"),
164
- location=os.getenv("GOOGLE_CLOUD_LOCATION")
165
  )
166
  return embeddings
167
  except Exception as e:
@@ -428,38 +425,96 @@ def get_product_images(product):
428
  print(f"Error retrieving product images: {e}")
429
  return []
430
 
431
- # Analyze product image with Gemini Vision
432
- def analyze_product_image_with_vision(image_data, query):
433
- """Analyze product image using Gemini Pro Vision"""
434
- if not image_data:
435
- return "No image data available for analysis"
436
 
437
  try:
438
- # Use Gemini 1.0 Pro Vision model
439
- model_name = "gemini-1.0-pro-vision-001"
440
- model = genai.GenerativeModel(model_name)
441
-
442
- # Create a vision-enabled prompt
443
- response = model.generate_content([
444
- "Analyze this ABB product image and answer the following question:",
445
- query,
446
- genai.types.Part.from_data(image_data, mime_type="image/jpeg")
447
- ])
448
-
449
- return response.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  except Exception as e:
451
- print(f"Error analyzing image with Gemini Vision: {e}")
452
- return "Error analyzing image. Please try a different query."
 
 
 
 
453
 
454
- def get_gemini_response(query, context_chunks=None):
455
- """Get enhanced response from Gemini model using RAG"""
 
456
  start_time = time.time()
457
 
458
  try:
459
- # Set up the model
460
- model_name = "gemini-2.0-flash-001"
461
- model = genai.GenerativeModel(model_name)
462
-
463
  # Detect product type from query
464
  product_keywords = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0}
465
  detected_product = "other"
@@ -489,8 +544,34 @@ def get_gemini_response(query, context_chunks=None):
489
  User query: {query}
490
  """
491
 
492
- # Generate response using Gemini
493
- response = model.generate_content(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
 
495
  # Update query counts for analytics
496
  if detected_product in query_counts:
@@ -502,12 +583,12 @@ def get_gemini_response(query, context_chunks=None):
502
  response_time = time.time() - start_time
503
  log_query_analytics(query, detected_product, response_time)
504
 
505
- return response.text, detected_product
506
  except Exception as e:
507
- print(f"Error processing chat request: {e}")
508
  return "Sorry, I encountered an error processing your request. Please try again.", "other"
509
 
510
- def chat_response(query, history):
511
  """Process query using RAG and generate response with product images"""
512
  global messages, product_images, current_product
513
 
@@ -517,8 +598,12 @@ def chat_response(query, history):
517
  # Get context from vector database
518
  context_chunks = search_vector_db(query)
519
 
520
- # Get LLM response with RAG
521
- response_text, detected_product = get_gemini_response(query, context_chunks)
 
 
 
 
522
 
523
  # Format new history entry
524
  new_history = history.copy()
@@ -536,6 +621,24 @@ def chat_response(query, history):
536
 
537
  return new_history
538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  def render_images():
540
  """Render product images as HTML (if available)"""
541
  if not product_images:
@@ -554,397 +657,219 @@ def render_images():
554
  html += "</div>"
555
  return html
556
 
557
- def render_product_distribution_chart():
558
- """Render product distribution chart using Plotly"""
559
- # Create a pie chart for product category distribution
560
- categories = list(query_counts.keys())
561
- values = list(query_counts.values())
562
-
563
- fig = go.Figure(data=[go.Pie(
564
- labels=categories,
565
- values=values,
566
- hole=.3,
567
- marker_colors=['#3b82f6', '#60a5fa', '#93c5fd', '#bfdbfe', '#dbeafe', '#f1f5f9']
568
- )])
569
-
570
- fig.update_layout(
571
- title="Product Query Distribution",
572
- margin=dict(t=40, b=20, l=20, r=20),
573
- height=300,
574
- legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01, orientation="h")
575
- )
576
-
577
- return fig
578
 
579
- def render_query_volume_chart():
580
- """Render query volume chart using Plotly"""
581
- # Create a line chart for query volume over time
582
- days = list(range(1, len(daily_queries) + 1))
583
-
584
- fig = go.Figure()
585
- fig.add_trace(go.Scatter(
586
- x=days,
587
- y=daily_queries,
588
- mode='lines+markers',
589
- name='Queries',
590
- line=dict(color='#3b82f6', width=2),
591
- marker=dict(color='#3b82f6', size=8)
592
- ))
593
-
594
- fig.update_layout(
595
- title="Daily Query Volume",
596
- xaxis_title="Day",
597
- yaxis_title="Number of Queries",
598
- margin=dict(t=40, b=20, l=20, r=20),
599
- height=300
600
- )
601
-
602
- return fig
603
 
604
- def render_metrics():
605
- """Render system metrics for the analytics tab with Plotly charts"""
606
- # Create metrics display with interactive charts
607
-
608
- # For system metrics section, use HTML
609
- html = """
610
- <div style='padding: 16px;'>
611
- <h3 style='margin-bottom: 16px; font-size: 18px;'>System Metrics</h3>
612
-
613
- <div style='display: grid; grid-template-columns: 1fr 1fr 1fr; gap: 16px; margin-bottom: 24px;'>
614
- <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
615
- <h4 style='font-size: 16px; margin-bottom: 8px; display: flex; align-items: center;'>
616
- <svg style='margin-right: 8px;' xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><path d="M14 2v6h6"/><path d="M16 13H8"/><path d="M16 17H8"/><path d="M10 9H8"/></svg>
617
- Document Processing
618
- </h4>
619
- <p style='font-size: 14px; color: #6b7280;'>4 PDF catalogs processed</p>
620
- <p style='font-size: 14px; color: #6b7280;'>1,248 text chunks extracted</p>
621
- <p style='font-size: 14px; color: #6b7280;'>136 images extracted</p>
622
- </div>
623
-
624
- <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
625
- <h4 style='font-size: 16px; margin-bottom: 8px; display: flex; align-items: center;'>
626
- <svg style='margin-right: 8px;' xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 18V6M7 10l5-4 5 4M7 14l5 4 5-4"/></svg>
627
- Vector Database
628
- </h4>
629
- <p style='font-size: 14px; color: #6b7280;'>Astra DB connected</p>
630
- <p style='font-size: 14px; color: #6b7280;'>1,248 text vectors stored</p>
631
- <p style='font-size: 14px; color: #6b7280;'>136 product images stored</p>
632
- </div>
633
-
634
- <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
635
- <h4 style='font-size: 16px; margin-bottom: 8px; display: flex; align-items: center;'>
636
- <svg style='margin-right: 8px;' xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 8V4H8"/><rect width="16" height="12" x="4" y="8" rx="2"/><path d="M2 14h2"/><path d="M20 14h2"/><path d="M15 13v2"/><path d="M9 13v2"/></svg>
637
- LLM Model
638
- </h4>
639
- <p style='font-size: 14px; color: #6b7280;'>Using: Gemini 2.0 Flash</p>
640
- <p style='font-size: 14px; color: #6b7280;'>Vision: Gemini 1.0 Pro Vision</p>
641
- <p style='font-size: 14px; color: #6b7280;'>Embeddings: VertexAI Embeddings</p>
642
- <p style='font-size: 14px; color: #6b7280;'>Using Service Account Auth</p>
643
- </div>
644
- </div>
645
- </div>
646
- """
647
-
648
- return html
649
 
650
- def render_advanced_pdf_ingestion():
651
- """UI for PDF catalog ingestion from S3"""
652
- html = """
653
- <div style='padding: 16px;'>
654
- <h3 style='margin-bottom: 16px; font-size: 18px;'>PDF Catalog Ingestion</h3>
655
- <p style='margin-bottom: 16px; color: #6b7280;'>
656
- Upload ABB product catalogs to S3 and process them for the knowledge base.
657
- </p>
658
-
659
- <div style='background: #f3f4f6; border-radius: 8px; padding: 16px; margin-bottom: 16px;'>
660
- <h4 style='font-size: 16px; margin-bottom: 8px;'>Current Status</h4>
661
- <ul style='list-style: disc; margin-left: 24px;'>
662
- <li style='margin-bottom: 4px;'>Connected to S3 bucket: <span style='font-weight: 500;'>abb-product-catalogs</span></li>
663
- <li style='margin-bottom: 4px;'>4 catalogs processed</li>
664
- <li style='margin-bottom: 4px;'>1,248 text chunks extracted and stored</li>
665
- <li style='margin-bottom: 4px;'>136 product images extracted and stored</li>
666
- <li style='margin-bottom: 4px;'>Last processed: March 8, 2025</li>
667
- </ul>
668
- </div>
669
-
670
- <div style='display: grid; grid-template-columns: 1fr 1fr; gap: 16px;'>
671
- <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
672
- <h4 style='font-size: 16px; margin-bottom: 8px;'>Available Catalogs</h4>
673
- <table style='width: 100%; border-collapse: collapse;'>
674
- <thead>
675
- <tr style='border-bottom: 1px solid #d1d5db;'>
676
- <th style='text-align: left; padding: 8px 4px;'>Filename</th>
677
- <th style='text-align: left; padding: 8px 4px;'>Size</th>
678
- <th style='text-align: left; padding: 8px 4px;'>Status</th>
679
- </tr>
680
- </thead>
681
- <tbody>
682
- <tr style='border-bottom: 1px solid #d1d5db;'>
683
- <td style='padding: 8px 4px;'>circuit_breaker_catalog.pdf</td>
684
- <td style='padding: 8px 4px;'>4.2 MB</td>
685
- <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
686
- </tr>
687
- <tr style='border-bottom: 1px solid #d1d5db;'>
688
- <td style='padding: 8px 4px;'>motor_starter_catalog.pdf</td>
689
- <td style='padding: 8px 4px;'>3.8 MB</td>
690
- <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
691
- </tr>
692
- <tr style='border-bottom: 1px solid #d1d5db;'>
693
- <td style='padding: 8px 4px;'>contactor_catalog.pdf</td>
694
- <td style='padding: 8px 4px;'>2.7 MB</td>
695
- <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
696
- </tr>
697
- <tr style='border-bottom: 1px solid #d1d5db;'>
698
- <td style='padding: 8px 4px;'>relay_catalog.pdf</td>
699
- <td style='padding: 8px 4px;'>1.9 MB</td>
700
- <td style='padding: 8px 4px;'><span style='color: #059669;'>Processed</span></td>
701
- </tr>
702
- <tr>
703
- <td style='padding: 8px 4px;'>switch_catalog_2024.pdf</td>
704
- <td style='padding: 8px 4px;'>3.1 MB</td>
705
- <td style='padding: 8px 4px;'><span style='color: #dc2626;'>Not Processed</span></td>
706
- </tr>
707
- </tbody>
708
- </table>
709
- </div>
710
-
711
- <div style='background: #f3f4f6; border-radius: 8px; padding: 16px;'>
712
- <h4 style='font-size: 16px; margin-bottom: 16px;'>Process Catalogs</h4>
713
- <button id="process-btn" style='background: #3b82f6; color: white; padding: 8px 16px; border: none; border-radius: 4px; cursor: pointer; font-weight: 500;'>
714
- Process All Catalogs
715
- </button>
716
- <p style='margin-top: 16px; color: #6b7280; font-size: 14px;'>
717
- This will process all PDF catalogs in the S3 bucket, extract text and images,
718
- generate embeddings, and store them in the vector database.
719
- </p>
720
- </div>
721
- </div>
722
- </div>
723
  """
724
-
725
- return html
726
 
727
- # For the image extraction and serving part, we need to add a function to temporarily store and serve images
728
- def serve_product_image(image_id):
729
- """Retrieve an image from Astra DB and serve it temporarily"""
730
- if not astra_session:
731
- return None
732
-
733
- try:
734
- # Query Astra DB for the specific image
735
- query = f"""
736
- SELECT image_data, metadata
737
- FROM {astra_keyspace}.product_images
738
- WHERE id = %s
739
- """
740
-
741
- rows = astra_session.execute(query, (image_id,))
742
-
743
- # Get the first matching row
744
- for row in rows:
745
- image_data = row.image_data
746
- metadata = json.loads(row.metadata)
747
-
748
- # Create a temporary file to serve
749
- temp_dir = os.path.join(os.getcwd(), "temp_images")
750
- os.makedirs(temp_dir, exist_ok=True)
751
-
752
- # Create a filename with the mime type
753
- mime_type = metadata.get("mime_type", "jpg")
754
- temp_file = os.path.join(temp_dir, f"{image_id}.{mime_type}")
755
-
756
- # Write the image to the temporary file
757
- with open(temp_file, "wb") as f:
758
- f.write(image_data)
759
-
760
- # Return the temporary file path
761
- return temp_file
762
- except Exception as e:
763
- print(f"Error serving product image: {e}")
764
- return None
765
 
766
- # Update the get_product_images function to use the temporary file paths
767
- def get_product_images(product):
768
- """Get product images from Astra DB and return temporary file paths"""
769
- global product_images
770
-
771
- if not astra_session:
772
- return []
773
-
774
- try:
775
- # Query Astra DB for images related to the product
776
- query = f"""
777
- SELECT id, product_type, metadata
778
- FROM {astra_keyspace}.product_images
779
- WHERE product_type = %s
780
- LIMIT 4
781
- """
782
-
783
- rows = astra_session.execute(query, (product,))
784
-
785
- # Store image paths for display
786
- image_paths = []
787
- for row in rows:
788
- # Get the image ID and serve it
789
- image_id = row.id
790
- temp_file = serve_product_image(image_id)
791
-
792
- if temp_file:
793
- # Use relative path for serving in the UI
794
- rel_path = os.path.relpath(temp_file, os.getcwd())
795
- image_paths.append(rel_path)
796
-
797
- # If no images found, use placeholder paths
798
- if not image_paths:
799
- # Create directory for placeholder images if it doesn't exist
800
- placeholder_dir = os.path.join(os.getcwd(), "placeholder_images")
801
- os.makedirs(placeholder_dir, exist_ok=True)
802
-
803
- # Create placeholder images
804
- for i in range(2):
805
- placeholder_file = os.path.join(
806
- placeholder_dir,
807
- f"placeholder-{product.lower().replace(' ', '-')}-{i+1}.jpg"
808
- )
809
- # Create a simple placeholder image if it doesn't exist
810
- if not os.path.exists(placeholder_file):
811
- # Generate a simple colored rectangle as placeholder
812
- from PIL import Image, ImageDraw, ImageFont
813
- img = Image.new('RGB', (400, 300), color=(240, 240, 240))
814
- d = ImageDraw.Draw(img)
815
- d.rectangle([(0, 0), (400, 300)], outline=(200, 200, 200))
816
- try:
817
- font = ImageFont.truetype("arial.ttf", 20)
818
- except IOError:
819
- font = ImageFont.load_default()
820
-
821
- d.text((120, 120), f"ABB {product}", fill=(100, 100, 100), font=font)
822
- img.save(placeholder_file)
823
-
824
- image_paths.append(os.path.relpath(placeholder_file, os.getcwd()))
825
-
826
- return image_paths
827
- except Exception as e:
828
- print(f"Error retrieving product images: {e}")
829
- return []
830
 
831
- # Update the render_images function to display actual images
832
- def render_images():
833
- """Render product images as HTML (if available)"""
834
- if not product_images:
835
- return ""
836
-
837
- html = "<div style='margin-top: 12px; display: grid; grid-template-columns: 1fr 1fr; gap: 8px;'>"
838
- for i, image_path in enumerate(product_images):
839
- # Convert backslashes to forward slashes for URLs
840
- url_path = image_path.replace("\\", "/")
841
- html += f"""
842
- <div style='background: #f3f4f6; border-radius: 6px; padding: 8px; text-align: center;'>
843
- <div style='height: 180px; display: flex; align-items: center; justify-content: center; background: rgba(0,0,0,0.05); border-radius: 4px; overflow: hidden;'>
844
- <img src="/{url_path}" alt="Product Image {i+1}" style="max-width: 100%; max-height: 160px; object-fit: contain;">
845
- </div>
846
- <p style='margin-top: 4px; font-size: 12px; text-overflow: ellipsis; overflow: hidden; white-space: nowrap;'>{os.path.basename(image_path)}</p>
847
- </div>
848
- """
849
- html += "</div>"
850
- return html
851
 
852
- # Setup cleanup function to remove temporary image files
853
- def cleanup_temp_files():
854
- """Clean up temporary image files that are older than 1 hour"""
855
- try:
856
- temp_dirs = ["temp_images", "placeholder_images"]
857
- current_time = time.time()
858
-
859
- for dir_name in temp_dirs:
860
- if os.path.exists(dir_name):
861
- for filename in os.listdir(dir_name):
862
- file_path = os.path.join(dir_name, filename)
863
- # Check if the file is older than 1 hour
864
- if os.path.isfile(file_path) and (current_time - os.path.getmtime(file_path) > 3600):
865
- os.remove(file_path)
866
- except Exception as e:
867
- print(f"Error cleaning up temporary files: {e}")
868
 
869
- # Schedule periodic cleanup of temporary files
870
- def schedule_cleanup():
871
- """Schedule periodic cleanup of temporary files"""
872
- import threading
873
-
874
- # Run cleanup
875
- cleanup_temp_files()
876
-
877
- # Schedule next cleanup in 30 minutes
878
- threading.Timer(1800, schedule_cleanup).start()
879
-
880
- # Initialize Gemini API, Astra DB, S3 client, and embedding model
881
- gemini_initialized = init_gemini_api()
882
- astra_session, astra_keyspace = init_astra_db()
883
- s3_client = init_s3_client()
884
- embeddings_model = get_embeddings_model()
885
-
886
- # Initialize main UI
887
- def create_ui():
888
- """Create the main Gradio UI with tabs for chat, analytics, and admin"""
889
- with gr.Blocks(title="ABB Product Assistant", css="") as demo:
890
- gr.Markdown("# ABB Product Assistant")
891
-
892
- with gr.Tabs() as tabs:
893
- # Chat tab
894
- with gr.TabItem("Chat"):
895
- chatbot = gr.Chatbot(value=[], elem_id="chatbot")
896
- with gr.Row():
897
- msg = gr.Textbox(placeholder="Ask about ABB products...", scale=4)
898
- submit = gr.Button("Send", scale=1)
899
-
900
- gr.HTML(render_images, elem_id="product-images")
901
-
902
- # Set up chat functionality
903
- submit.click(
904
- chat_response,
905
- [msg, chatbot],
906
- [chatbot],
907
- queue=False
908
- ).then(
909
- lambda: "",
910
- None,
911
- [msg],
912
- queue=False
913
- )
914
-
915
- msg.submit(
916
- chat_response,
917
- [msg, chatbot],
918
- [chatbot],
919
- queue=False
920
- ).then(
921
- lambda: "",
922
- None,
923
- [msg],
924
- queue=False
925
- )
926
-
927
- # Analytics tab
928
- with gr.TabItem("Analytics"):
929
- gr.HTML(render_metrics)
930
-
931
- with gr.Row():
932
- with gr.Column():
933
- gr.Plot(render_product_distribution_chart)
934
- with gr.Column():
935
- gr.Plot(render_query_volume_chart)
936
-
937
- # Admin tab
938
- with gr.TabItem("Admin"):
939
- gr.HTML(render_advanced_pdf_ingestion)
940
-
941
- return demo
942
 
943
  # Start the application
944
  if __name__ == "__main__":
945
- # Schedule cleanup of temporary files
946
- schedule_cleanup()
947
-
948
  # Create and launch the UI
949
- demo = create_ui()
950
  demo.launch(share=True)
 
1
  import os
2
  import gradio as gr
 
3
  import pandas as pd
4
  import plotly.express as px
5
  import plotly.graph_objects as go
 
12
  import time
13
  import numpy as np
14
  import pdfplumber
15
+ import requests
16
  from dotenv import load_dotenv
17
  from cassandra.cluster import Cluster
18
  from cassandra.auth import PlainTextAuthProvider
19
  from cassandra.query import SimpleStatement
20
  from langchain.text_splitter import RecursiveCharacterTextSplitter
21
  from langchain_community.vectorstores import Cassandra
22
+ from langchain_openai import OpenAIEmbeddings
23
+ from PIL import Image, ImageDraw, ImageFont
24
 
25
  # Load environment variables
26
  load_dotenv()
 
32
  query_counts = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0, "other": 0}
33
  daily_queries = [0, 0, 0, 0, 0, 6, 8, 10, 7, 9, 12, 15, 11, 14] # Mock data for chart
34
 
35
+ # Initialize OpenAI API
36
+ def init_openai_api():
37
+ """Initialize OpenAI API with API key from Hugging Face Secrets"""
38
  try:
39
+ # Get API key from environment (set by Hugging Face Secrets)
40
+ openai_api_key = os.getenv("OPENAI_API_KEY")
41
+ if not openai_api_key:
42
+ print("OPENAI_API_KEY is not set in environment variables")
43
+ return False
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ # Set as environment variable for libraries that use it directly
46
+ os.environ["OPENAI_API_KEY"] = openai_api_key
47
+ print("OpenAI API initialized with API key from Hugging Face Secrets")
48
  return True
49
 
50
  except Exception as e:
51
+ print(f"Error initializing OpenAI API: {e}")
52
+ return False
53
+
54
+ # Initialize Mistral API
55
+ def init_mistral_api():
56
+ """Initialize Mistral API with API key from Hugging Face Secrets"""
57
+ try:
58
+ # Get API key from environment (set by Hugging Face Secrets)
59
+ mistral_api_key = os.getenv("MISTRAL_API_KEY")
60
+ if not mistral_api_key:
61
+ print("MISTRAL_API_KEY is not set in environment variables")
62
  return False
63
+
64
+ # Set as environment variable for libraries that use it directly
65
+ os.environ["MISTRAL_API_KEY"] = mistral_api_key
66
+ print("Mistral API initialized with API key from Hugging Face Secrets")
67
+ return True
68
+
69
+ except Exception as e:
70
+ print(f"Error initializing Mistral API: {e}")
71
+ return False
72
 
73
  # Initialize Astra DB connection
74
  def init_astra_db():
 
154
 
155
  # Initialize embedding model
156
  def get_embeddings_model():
157
+ """Initialize the OpenAI embeddings model for vector generation"""
158
  try:
159
+ embeddings = OpenAIEmbeddings(
160
+ model="text-embedding-ada-002",
161
+ openai_api_key=os.getenv("OPENAI_API_KEY")
162
  )
163
  return embeddings
164
  except Exception as e:
 
425
  print(f"Error retrieving product images: {e}")
426
  return []
427
 
428
+ # Get response from OpenAI API
429
+ def get_openai_response(query, context_chunks=None):
430
+ """Get enhanced response from OpenAI model using RAG"""
431
+ start_time = time.time()
 
432
 
433
  try:
434
+ # Detect product type from query
435
+ product_keywords = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0}
436
+ detected_product = "other"
437
+
438
+ for keyword in product_keywords:
439
+ if keyword in query.lower():
440
+ product_keywords[keyword] += 1
441
+ if product_keywords[keyword] > product_keywords.get(detected_product, -1):
442
+ detected_product = keyword
443
+
444
+ # If no context chunks provided, search the vector DB
445
+ if not context_chunks:
446
+ context_chunks = search_vector_db(query, product_type=detected_product if detected_product != "other" else None)
447
+
448
+ # Build context from retrieved chunks
449
+ context_text = "\n\n".join([chunk["content"] for chunk in context_chunks]) if context_chunks else ""
450
+
451
+ # Create prompt with context
452
+ prompt = f"""
453
+ You are an assistant specialized in ABB products and solutions. Answer the following query about ABB products with accurate and helpful information.
454
+
455
+ Use the following product information to inform your response:
456
+ {context_text}
457
+
458
+ If the information above doesn't contain relevant details, use your general knowledge about industrial electrical equipment, but be clear about what information comes from the ABB catalog versus general knowledge.
459
+
460
+ User query: {query}
461
+ """
462
+
463
+ # Call OpenAI API
464
+ headers = {
465
+ "Content-Type": "application/json",
466
+ "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
467
+ }
468
+
469
+ payload = {
470
+ "model": "gpt-4o",
471
+ "messages": [
472
+ {"role": "system", "content": "You are an assistant specialized in ABB products and solutions."},
473
+ {"role": "user", "content": prompt}
474
+ ],
475
+ "temperature": 0.7,
476
+ "max_tokens": 800
477
+ }
478
+
479
+ response = requests.post(
480
+ "https://api.openai.com/v1/chat/completions",
481
+ headers=headers,
482
+ json=payload
483
+ )
484
+
485
+ if response.status_code == 200:
486
+ response_json = response.json()
487
+ response_text = response_json["choices"][0]["message"]["content"]
488
+ else:
489
+ # Fallback to Mistral if OpenAI fails
490
+ print(f"OpenAI API error: {response.status_code}, {response.text}")
491
+ response_text = get_mistral_response(query, context_chunks)
492
+
493
+ # Update query counts for analytics
494
+ if detected_product in query_counts:
495
+ query_counts[detected_product] += 1
496
+ else:
497
+ query_counts["other"] += 1
498
+
499
+ # Log analytics
500
+ response_time = time.time() - start_time
501
+ log_query_analytics(query, detected_product, response_time)
502
+
503
+ return response_text, detected_product
504
  except Exception as e:
505
+ print(f"Error processing chat request with OpenAI: {e}")
506
+ # Fallback to Mistral
507
+ try:
508
+ return get_mistral_response(query, context_chunks)
509
+ except:
510
+ return "Sorry, I encountered an error processing your request. Please try again.", "other"
511
 
512
+ # Get response from Mistral API (fallback)
513
+ def get_mistral_response(query, context_chunks=None):
514
+ """Get enhanced response from Mistral model using RAG (fallback)"""
515
  start_time = time.time()
516
 
517
  try:
 
 
 
 
518
  # Detect product type from query
519
  product_keywords = {"circuit breaker": 0, "motor starter": 0, "contactor": 0, "switch": 0, "relay": 0}
520
  detected_product = "other"
 
544
  User query: {query}
545
  """
546
 
547
+ # Call Mistral API
548
+ headers = {
549
+ "Content-Type": "application/json",
550
+ "Authorization": f"Bearer {os.getenv('MISTRAL_API_KEY')}"
551
+ }
552
+
553
+ payload = {
554
+ "model": "mistral-large-latest",
555
+ "messages": [
556
+ {"role": "system", "content": "You are an assistant specialized in ABB products and solutions."},
557
+ {"role": "user", "content": prompt}
558
+ ],
559
+ "temperature": 0.7,
560
+ "max_tokens": 800
561
+ }
562
+
563
+ response = requests.post(
564
+ "https://api.mistral.ai/v1/chat/completions",
565
+ headers=headers,
566
+ json=payload
567
+ )
568
+
569
+ if response.status_code == 200:
570
+ response_json = response.json()
571
+ response_text = response_json["choices"][0]["message"]["content"]
572
+ else:
573
+ print(f"Mistral API error: {response.status_code}, {response.text}")
574
+ response_text = "Sorry, I encountered an error processing your request. Please try again."
575
 
576
  # Update query counts for analytics
577
  if detected_product in query_counts:
 
583
  response_time = time.time() - start_time
584
  log_query_analytics(query, detected_product, response_time)
585
 
586
+ return response_text, detected_product
587
  except Exception as e:
588
+ print(f"Error processing chat request with Mistral: {e}")
589
  return "Sorry, I encountered an error processing your request. Please try again.", "other"
590
 
591
+ def process_message(query, history):
592
  """Process query using RAG and generate response with product images"""
593
  global messages, product_images, current_product
594
 
 
598
  # Get context from vector database
599
  context_chunks = search_vector_db(query)
600
 
601
+ # Get LLM response with RAG (try OpenAI first, fallback to Mistral)
602
+ try:
603
+ response_text, detected_product = get_openai_response(query, context_chunks)
604
+ except Exception as e:
605
+ print(f"Error with OpenAI, falling back to Mistral: {e}")
606
+ response_text, detected_product = get_mistral_response(query, context_chunks)
607
 
608
  # Format new history entry
609
  new_history = history.copy()
 
621
 
622
  return new_history
623
 
624
+ def reset_chat(history):
625
+ """Reset the chat history"""
626
+ return []
627
+
628
+ def process_pdfs_from_s3(bucket_name, prefix):
629
+ """Process PDFs from S3 bucket"""
630
+ # Set environment variable for S3 bucket
631
+ os.environ["S3_BUCKET_NAME"] = bucket_name
632
+
633
+ # Process PDFs
634
+ result = process_pdf_catalogs()
635
+
636
+ # Return result as string
637
+ if result["status"] == "success":
638
+ return f"Successfully processed {result['files_processed']} files, {result['chunks_processed']} chunks, and {result['images_processed']} images."
639
+ else:
640
+ return f"Error: {result['message']}"
641
+
642
  def render_images():
643
  """Render product images as HTML (if available)"""
644
  if not product_images:
 
657
  html += "</div>"
658
  return html
659
 
660
+ def setup_and_update():
661
+ """Setup the system and update status"""
662
+ # Initialize APIs
663
+ openai_initialized = init_openai_api()
664
+ mistral_initialized = init_mistral_api()
665
+
666
+ # Initialize database and other services
667
+ global astra_session, astra_keyspace, s3_client, embeddings_model
668
+ astra_session, astra_keyspace = init_astra_db()
669
+ s3_client = init_s3_client()
670
+ embeddings_model = get_embeddings_model()
671
+
672
+ # Return status
673
+ if openai_initialized and mistral_initialized:
674
+ return "System is ready. You can start chatting!"
675
+ else:
676
+ return "System initialization incomplete. Some features may not work properly."
 
 
 
 
677
 
678
+ def create_gradio_app():
679
+ # Define CSS styles for a more modern, appealing interface
680
+ custom_css = """
681
+ :root {
682
+ --primary-color: #FF000C;
683
+ --secondary-color: #212832;
684
+ --background-color: var(--body-background-fill);
685
+ --card-color: var(--block-background-fill);
686
+ --text-color: var(--body-text-color);
687
+ --border-radius: 12px;
688
+ --shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
689
+ }
 
 
 
 
 
 
 
 
 
 
 
 
690
 
691
+ .app-header {
692
+ background-color: var(--secondary-color);
693
+ padding: 20px;
694
+ border-radius: var(--border-radius);
695
+ margin-bottom: 20px;
696
+ box-shadow: var(--shadow);
697
+ display: flex;
698
+ align-items: center;
699
+ justify-content: space-between;
700
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
701
 
702
+ .app-header img {
703
+ max-width: 120px;
704
+ }
705
+
706
+ .app-title {
707
+ color: white;
708
+ margin: 0;
709
+ font-size: 24px;
710
+ font-weight: 600;
711
+ }
712
+
713
+ .status-card, .catalog-card, .chat-card {
714
+ background-color: var(--card-color);
715
+ border-radius: var(--border-radius);
716
+ padding: 15px;
717
+ margin-bottom: 20px;
718
+ box-shadow: var(--shadow);
719
+ }
720
+
721
+ .chat-card {
722
+ height: 100%;
723
+ }
724
+
725
+ .message {
726
+ padding: 10px 15px;
727
+ border-radius: 8px;
728
+ margin-bottom: 10px;
729
+ max-width: 85%;
730
+ }
731
+
732
+ .user-message {
733
+ background-color: var(--primary-color);
734
+ color: white;
735
+ margin-left: auto;
736
+ }
737
+
738
+ .bot-message {
739
+ background-color: #f0f0f0;
740
+ color: var(--text-color);
741
+ margin-right: auto;
742
+ }
743
+
744
+ .footer {
745
+ text-align: center;
746
+ margin-top: 20px;
747
+ font-size: 12px;
748
+ color: var(--text-color);
749
+ }
750
+
751
+ .action-button {
752
+ background-color: var(--primary-color);
753
+ color: white;
754
+ border: none;
755
+ border-radius: var(--border-radius);
756
+ padding: 8px 16px;
757
+ cursor: pointer;
758
+ transition: all 0.3s ease;
759
+ }
760
+
761
+ .action-button:hover {
762
+ opacity: 0.9;
763
+ }
 
 
 
 
 
 
 
 
 
 
 
764
  """
 
 
765
 
766
+ # Create the Gradio interface
767
+ with gr.Blocks(css=custom_css) as app:
768
+ # Setup status variable
769
+ setup_status = gr.State("System is setting up. Please wait...")
770
+ status_display = gr.Markdown("System is setting up. Please wait...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
771
 
772
+ with gr.Column(scale=1):
773
+ # Modern header
774
+ with gr.Row(elem_classes="app-header"):
775
+ with gr.Column(scale=1):
776
+ gr.Image(value="https://upload.wikimedia.org/wikipedia/commons/thumb/0/00/ABB_logo.svg/2560px-ABB_logo.svg.png",
777
+ width=120,
778
+ height=120,
779
+ interactive=False,
780
+ label="ABB Logo")
781
+ with gr.Column(scale=3):
782
+ gr.HTML('<h1 class="app-title">Ginnie</h1>')
783
+ gr.HTML('<p class="app-subtitle">Your AI assistant for ABB product information</p>')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
784
 
785
+ # Chat interface
786
+ with gr.Row():
787
+ with gr.Column(scale=3):
788
+ # Chat interface with custom styling
789
+ gr.HTML('<div class="content-card">')
790
+ chatbot = gr.Chatbot(
791
+ value=[],
792
+ elem_id="chatbot",
793
+ height=500,
794
+ show_copy_button=True,
795
+ avatar_images=["https://ui-avatars.com/api/?name=You&background=0D8ABC&color=fff",
796
+ "https://ui-avatars.com/api/?name=Ginnie&background=FF000C&color=fff"]
797
+ )
 
 
 
 
 
 
 
798
 
799
+ # Message input with better styling
800
+ with gr.Row(elem_classes="input-area"):
801
+ msg = gr.Textbox(
802
+ placeholder="Ask about ABB products...",
803
+ label="",
804
+ lines=2,
805
+ max_lines=5,
806
+ show_label=False
807
+ )
 
 
 
 
 
 
 
808
 
809
+ send_btn = gr.Button("Send", elem_classes="primary-button")
810
+
811
+ with gr.Row():
812
+ clear_btn = gr.Button("Clear Chat", elem_classes="secondary-button")
813
+ gr.HTML('</div>')
814
+
815
+ with gr.Column(scale=1):
816
+ # Quick tips card
817
+ gr.HTML('<div class="status-card">')
818
+ gr.HTML('''
819
+ <h3>Quick Tips</h3>
820
+ <ul>
821
+ <li>Ask about specific ABB products</li>
822
+ <li>Inquire about technical specifications</li>
823
+ <li>Ask about installation and maintenance</li>
824
+ <li>Get help with troubleshooting</li>
825
+ </ul>
826
+ ''')
827
+ gr.HTML('</div>')
828
+
829
+ # Admin settings
830
+ with gr.Accordion("Admin Settings", open=False):
831
+ with gr.Tab("Process PDFs"):
832
+ s3_bucket = gr.Textbox(label="S3 Bucket Name")
833
+ s3_prefix = gr.Textbox(label="S3 Prefix (folder)", value="catalogs/")
834
+ process_btn = gr.Button("Process PDFs from S3", elem_classes="action-button")
835
+ result_text = gr.Textbox(label="Processing Result")
836
+
837
+ # Set up event handlers
838
+ send_btn.click(
839
+ process_message,
840
+ [msg, chatbot],
841
+ [chatbot],
842
+ api_name="send_message"
843
+ )
844
+
845
+ msg.submit(
846
+ process_message,
847
+ [msg, chatbot],
848
+ [chatbot],
849
+ api_name="send_message_enter"
850
+ )
851
+
852
+ clear_btn.click(
853
+ reset_chat,
854
+ [chatbot],
855
+ [chatbot],
856
+ api_name="clear_chat"
857
+ )
858
+
859
+ process_btn.click(
860
+ process_pdfs_from_s3,
861
+ [s3_bucket, s3_prefix],
862
+ [result_text],
863
+ api_name="process_pdfs"
864
+ )
865
+
866
+ # Add the system setup to run when the app loads
867
+ app.load(setup_and_update, None, status_display)
868
+
869
+ return app
 
 
 
 
 
 
 
 
 
 
 
 
870
 
871
  # Start the application
872
  if __name__ == "__main__":
 
 
 
873
  # Create and launch the UI
874
+ demo = create_gradio_app()
875
  demo.launch(share=True)