Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import logging | |
| import os | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| import faiss | |
| from simple_salesforce import Salesforce | |
| from dotenv import load_dotenv | |
| import zipfile | |
| from pathlib import Path | |
# Configure logging once and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the .env file so the os.getenv() lookups below can see its values.
load_dotenv()

# Salesforce credentials, each read from an environment variable.
sf_username = os.getenv("SF_USERNAME")
sf_password = os.getenv("SF_PASSWORD")
sf_security_token = os.getenv("SF_SECURITY_TOKEN")
sf_instance_url = os.getenv("SF_INSTANCE_URL")

# Fail fast: every one of the four settings must be present and non-empty.
if not all((sf_username, sf_password, sf_security_token, sf_instance_url)):
    logger.error("❌ Salesforce credentials are missing from environment variables!")
    raise ValueError("Salesforce credentials are not properly set.")

# Open the Salesforce session used by the rest of the script. A failure is
# logged and then re-raised, so the script stops instead of running without
# a connection.
try:
    sf = Salesforce(
        username=sf_username,
        password=sf_password,
        security_token=sf_security_token,
        instance_url=sf_instance_url,
    )
except Exception as exc:
    logger.error(f"❌ Salesforce connection failed: {exc!s}")
    raise
else:
    logger.info("✅ Connected to Salesforce")
| # --- Extract zip files and read documents --- | |
def extract_zip(zip_path, extract_to):
    """Unpack a zip archive into a directory.

    Args:
        zip_path: Location of the archive to open for reading.
        extract_to: Directory that receives the archive's contents.

    Raises:
        Exception: Whatever opening or extracting the archive raised; the
            error is logged first and then re-raised unchanged.
    """
    try:
        with zipfile.ZipFile(zip_path, mode='r') as archive:
            archive.extractall(path=extract_to)
    except Exception as exc:
        logger.error(f"Failed to extract {zip_path}: {exc!s}")
        raise
    else:
        logger.info(f"Extracted {zip_path} to {extract_to}")
def load_documents(folder_path):
    """Read every ``*.txt`` file found under ``folder_path``, recursively.

    Only regular files are read: a directory whose name happens to end in
    ``.txt`` also matches the glob pattern and would make ``read_text``
    raise. Files are visited in sorted path order so the result (and
    anything built from it) is the same on every run, regardless of the
    order in which the filesystem lists entries.

    Args:
        folder_path: Directory to search (``str`` or ``Path``).

    Returns:
        A ``(documents, sources)`` pair of equal-length lists.
        ``documents[i]`` is the text of the i-th file, decoded as UTF-8 with
        undecodable bytes dropped; ``sources[i]`` is that file's bare name
        (no directory part).
    """
    documents = []
    sources = []
    for txt_file in sorted(Path(folder_path).rglob("*.txt")):
        if not txt_file.is_file():
            # The pattern matched a directory (or another non-file entry).
            continue
        documents.append(txt_file.read_text(encoding="utf-8", errors="ignore"))
        sources.append(txt_file.name)
    return documents, sources
# --- Chunking ---
# Splitter applied to every loaded document before embedding.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=50)

# --- Load model ---
# Sentence-embedding model shared by index construction and query time.
model = SentenceTransformer("all-MiniLM-L6-v2")

# --- Preprocessing ---
# Working directory that receives the extracted archives.
data_dir = Path("./data")
data_dir.mkdir(exist_ok=True)

# (archive file name, extraction sub-folder) pairs to ingest.
doc_folders = [
    ("Company_Policies.zip", "Company_Policies"),
    ("HR_Policies.zip", "Hr_Policies"),
    ("Contract_Clauses.zip", "Contract_Clauses"),
]

# Parallel lists: metadata[i] is the source URL for all_chunks[i].
all_chunks = []
metadata = []

for zip_name, folder in doc_folders:
    zip_path = Path(zip_name)
    if not zip_path.exists():
        logger.error(f"Zip file {zip_name} not found")
        raise FileNotFoundError(f"Zip file {zip_name} not found")

    extract_path = data_dir / folder
    extract_path.mkdir(exist_ok=True)
    extract_zip(zip_path, extract_path)

    docs, sources = load_documents(extract_path)
    if not docs:
        logger.error(f"No documents found in {extract_path}")
        raise ValueError(f"No documents found in {extract_path}")

    for doc, src in zip(docs, sources):
        chunks = text_splitter.split_text(doc)
        src_url = f"https://company.com/{folder}/{src}"
        # One metadata entry per chunk keeps the two lists index-aligned.
        all_chunks += chunks
        metadata += [src_url] * len(chunks)

# --- Embeddings + FAISS index ---
# Embed every chunk, then load the vectors into a flat L2 index whose
# dimensionality is taken from the embedding matrix itself.
embeddings = model.encode(all_chunks)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(np.array(embeddings))
logger.info("FAISS index built successfully")
| # --- Create Record in Salesforce --- | |
def create_salesforce_record(query, answer, confidence_percentage, source_link):
    """Log one question/answer exchange as a ``chat_query_log__c`` record.

    Args:
        query: The user's question, stored in ``Query__c``.
        answer: The answer text, stored in ``Answer__c``.
        confidence_percentage: Confidence score; coerced to a built-in
            ``float`` (e.g. from a numpy float32) before being sent, and
            stored in ``Confidence_Percentage__c``.
        source_link: Document link, stored in ``Document_link__c``.

    Returns:
        The ``id`` from the create response on success, or ``None`` when the
        response has no ``id`` or when any exception is raised (the failure
        is logged in both cases, never propagated).
    """
    try:
        payload = {
            "Query__c": query,
            "Answer__c": answer,
            "Confidence_Percentage__c": float(confidence_percentage),
            "Document_link__c": source_link,
        }
        result = sf.chat_query_log__c.create(payload)

        # A response without an 'id' is treated as a failed creation.
        if 'id' not in result:
            logger.error(f"❌ Failed to create Salesforce record. Response: {result}")
            return None

        new_id = result['id']
        logger.info(f"✅ Record created successfully in Salesforce with ID: {new_id}")
        return new_id
    except Exception as exc:
        logger.error(f"Error creating Salesforce record: {exc!s}")
        return None
| # --- Search & Answer --- | |
def answer_query(query):
    """Answer ``query`` from the indexed document chunks and log the result.

    The query is embedded with the shared model, the three nearest chunks
    are fetched from the FAISS index, and hits with a distance of 0.8 or
    more are discarded. The first remaining chunk is the answer; confidence
    is ``100 - 100 * smallest_distance``, floored at 0. The exchange is then
    written to Salesforce via ``create_salesforce_record``.

    Args:
        query: The user's question.

    Returns:
        Always a 4-tuple of strings
        ``(answer, confidence_text, source_text, record_text)``:

        * no relevant chunk: a fixed "not found" message, 0% confidence,
          no source, and an empty record text (nothing is logged);
        * success: the answer plus either the Salesforce record ID or a
          note that the record could not be created;
        * any exception: ``("Error: ...", "", "", "")``.
    """
    try:
        logger.info(f"Processing query: {query}")
        query_embedding = model.encode([query])
        D, I = index.search(np.array(query_embedding), k=3)
        distances = D[0]

        # Keep (chunk, source) for hits that are close enough. FAISS reports
        # a label of -1 for result slots it could not fill; skip those so a
        # negative index never silently wraps to the end of the lists.
        relevant = [
            (all_chunks[i], metadata[i])
            for i, dist in zip(I[0], distances)
            if i >= 0 and dist < 0.8
        ]

        if not relevant:
            # Four values, like every other branch: the caller unpacks four
            # and the UI wires four outputs.
            return (
                "No relevant information found.",
                "Confidence: 0%",
                "Source Link: None",
                "",
            )

        best_chunk, source_link = relevant[0]
        answer = best_chunk.strip()

        # Plain float (not a numpy scalar) for arithmetic and formatting.
        min_distance = float(min(distances))
        confidence_percentage = max(0.0, 100.0 - min_distance * 100.0)

        # Create Salesforce record for the query response.
        record_id = create_salesforce_record(
            query, answer, confidence_percentage, source_link
        )
        if record_id:
            record_text = f"Salesforce Record ID: {record_id}"
        else:
            record_text = "Failed to create record in Salesforce"

        return (
            answer,
            f"Confidence: {confidence_percentage:.2f}%",
            f"Source Link: {source_link}",
            record_text,
        )
    except Exception as e:
        logger.error(f"Error in answer_query: {str(e)}")
        return f"Error: {str(e)}", "", "", ""
| # --- Gradio Chatbot UI Design --- | |
def process_question(q, chat_history):
    """Handle one submitted question for the chat UI.

    Args:
        q: Text from the question box. ``None``, empty, or whitespace-only
            input is treated as "no question".
        chat_history: Current chat history as a list of 2-item entries, or
            ``None`` when there is no history yet.

    Returns:
        A 4-tuple ``(chat_history, confidence, source, record_id)`` matching
        the four outputs wired to the submit button. For a blank question
        the history gains a prompt entry and the other three values are
        empty strings; ``answer_query`` is not called in that case.
    """
    history = chat_history if chat_history is not None else []

    if not q or not q.strip():
        # Build a new list here (the passed-in history is left unmodified)
        # and return four values so every declared output gets one.
        return history + [("User", "Please enter a question.")], "", "", ""

    answer, confidence, source, record_id = answer_query(q)
    # NOTE(review): entries are ("User", text) / ("Bot", text) label pairs,
    # kept exactly as before — confirm this is the pairing the Chatbot
    # component is meant to display.
    history.append(("User", q))
    history.append(("Bot", answer))
    return history, confidence, source, record_id
| # --- Chatbot UI with dynamic styling using elem_id --- | |
# NOTE(review): the nesting below is reconstructed from syntax because the
# original indentation was not available — confirm the layout renders as
# intended.
with gr.Blocks(theme=gr.themes.Soft(), title="Company Documents Q&A Chatbot") as demo:
    gr.Markdown("## 📚 **Company Policies Q&A Chatbot**")

    # Input row: question box beside the submit button.
    with gr.Row():
        with gr.Column(scale=3):
            question = gr.Textbox(
                elem_id="user-question",
                label="Ask a Question",
                placeholder="What are the conditions for permanent employment status?",
                lines=1,
                interactive=True,
                visible=True,
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button("Submit", elem_id="submit-btn", variant="primary")

    # Output area: chat transcript followed by the three status fields.
    with gr.Row():
        with gr.Column():
            chat_history = gr.Chatbot(
                elem_id="chatbox",
                label="Chat History",
                show_label=False,  # label hidden for a cleaner chat panel
                height=400,  # fixed height
            )
            conf_out = gr.Markdown(elem_id="confidence", label="Confidence")
            source_out = gr.Markdown(elem_id="source-link", label="Source Link")
            record_out = gr.Markdown(elem_id="salesforce-id", label="Salesforce Record ID")

    # The handler's four return values map onto these four outputs, in order.
    submit_btn.click(
        fn=process_question,
        inputs=[question, chat_history],
        outputs=[chat_history, conf_out, source_out, record_out],
    )

demo.launch(server_name="0.0.0.0", server_port=7860, share=True)