Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import requests | |
| import gradio as gr | |
| import threading | |
| import time | |
| import PyPDF2 | |
| import chromadb | |
| import shutil | |
| from pydantic import BaseModel, Field | |
| from typing import Dict | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
# API credentials and endpoint. The key is read from the "mistral" environment
# variable and is None when that variable is not set.
API_KEY = os.environ.get("mistral")
BASE_URL = "https://api.together.xyz"

# Store user inputs
# Every field starts out empty; classify_query() refuses to run until all of
# them have been filled in.
user_inputs = dict.fromkeys(
    ("organization", "rules_l1", "rules_l2", "rules_l3"),
    "",
)
# Function to classify query
def classify_query(query: str) -> str:
    """Ask the chat-completions API to classify a customer query.

    The system prompt is assembled from the organization name and the three
    rule sets held in the module-level ``user_inputs`` dict, and the request
    is posted to ``{BASE_URL}/chat/completions``.

    Args:
        query: The customer's message, sent as the user turn.

    Returns:
        The ``content`` of the first choice in the API response, returned as
        received (a string; it is not parsed here). The request asks the model
        for a JSON object with ``title`` and ``level``.

    Raises:
        ValueError: If any value in ``user_inputs`` is empty, or if no API key
            is configured.
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status.
        RuntimeError: If the response body does not have the expected
            ``choices[0].message.content`` shape.
    """
    if not all(user_inputs.values()):
        raise ValueError("Please fill all input fields first.")
    if not API_KEY:
        # Without this check the request would go out as "Bearer None" and
        # come back as an opaque authorization failure.
        raise ValueError("API key is not configured; set the 'mistral' environment variable.")

    # Without a timeout, requests waits forever on a stalled connection.
    request_timeout_seconds = 60

    system_prompt = "\n".join([
        f"You are a Customer Query Classification Agent for {user_inputs['organization']}.",
        "What is considered Level 1 Query (Requires no account info just provided documents by the admin is enough to answer):",
        str(user_inputs["rules_l1"]),
        "What is considered Level 2 Query (Requires account info and provided documents by the admin is enough to answer):",
        str(user_inputs["rules_l2"]),
        "What is considered as Level 3 Query (Immediate Escalation to Human Customer Service Agents):",
        str(user_inputs["rules_l3"]),
        "Classify the following customer query and provide the output in JSON format:",
        "```json",
        "{",
        '"title": "title of the query in under 10 words",',
        '"level": "1 or 2 or 3"',
        "}",
        "```",
    ])
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    data = {
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": messages,
        "temperature": 0.7,
        "response_format": {
            "type": "json_object",
            "schema": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "level": {"type": "integer"},
                },
                "required": ["title", "level"],
            },
        },
    }

    response = requests.post(
        f"{BASE_URL}/chat/completions",
        headers=headers,
        json=data,
        timeout=request_timeout_seconds,
    )
    response.raise_for_status()

    try:
        return response.json()["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as exc:
        raise RuntimeError("Unexpected response format from the classification API.") from exc
# Function to convert PDF to text
def pdf_to_text(file_path):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        file_path: Path of the PDF file to read.

    Returns:
        One string holding the extracted text of all pages. Pages for which
        the reader yields no text contribute an empty string.
    """
    # The context manager closes the file even when parsing or extraction
    # raises part-way through.
    with open(file_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # Extraction must finish while the file is still open.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)
# Function to handle file upload and save embeddings to ChromaDB
def handle_file_upload(files, collection_name):
    """Index uploaded PDFs into a new ChromaDB collection.

    Each upload is copied into the ``chabot_pdfs`` directory, converted to
    text with ``pdf_to_text``, split with a ``RecursiveCharacterTextSplitter``
    (``chunk_size=1000``, ``chunk_overlap=100``), embedded with the
    ``thenlper/gte-small`` model and added to a collection in the persistent
    database at ``./db``. Chunk ids are ``"<file name>_<chunk index>"``.

    Args:
        files: Uploaded files; each item is either a path or an object whose
            ``name`` attribute is the path of the uploaded file. May be
            ``None`` or empty when nothing was uploaded.
        collection_name: Name of the collection to create.

    Returns:
        A status message for display: a success message (naming any files
        that had no extractable text), or a description of what went wrong.
    """
    if not collection_name:
        return "Please provide a collection name."
    # Validate the upload before touching the database, so that an empty
    # submission does not create (and thereby use up) the collection name.
    if not files:
        return "Please upload at least one PDF file."

    os.makedirs('chabot_pdfs', exist_ok=True)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-small")

    # Initialize Chroma DB client
    client = chromadb.PersistentClient(path="./db")
    try:
        collection = client.create_collection(name=collection_name)
    except ValueError as e:
        return f"Error creating collection: {str(e)}. Please try a different collection name."

    files_without_text = []
    for file in files:
        # Accept both plain path strings and file objects exposing `.name`.
        source_path = getattr(file, "name", file)
        file_name = os.path.basename(source_path)
        file_path = os.path.join('chabot_pdfs', file_name)
        shutil.copy(source_path, file_path)  # Copy the file instead of saving

        chunks = text_splitter.split_text(pdf_to_text(file_path))
        if not chunks:
            # Nothing to index (e.g. an image-only PDF); adding empty lists
            # to the collection would fail.
            files_without_text.append(file_name)
            continue

        collection.add(
            # Embed the whole file in one batched call rather than per chunk.
            embeddings=embeddings.embed_documents(chunks),
            documents=chunks,
            ids=[f"{file_name}_{i}" for i in range(len(chunks))],
        )

    if files_without_text:
        return (
            "Files uploaded and processed successfully. "
            "No extractable text was found in: " + ", ".join(files_without_text)
        )
    return "Files uploaded and processed successfully."
# Function to search vector database
def search_vector_database(query, collection_name):
    """Look up the stored chunks closest to ``query`` in a named collection.

    The query is embedded with the ``thenlper/gte-small`` model and matched
    against the collection in the persistent database at ``./db``; the two
    nearest documents are requested.

    Args:
        query: Free-text search string.
        collection_name: Name of an existing collection.

    Returns:
        The matched document texts joined into one string, or a message
        explaining why the search could not be run.
    """
    if not collection_name:
        return "Please provide a collection name."

    embedder = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
    db_client = chromadb.PersistentClient(path="./db")
    try:
        target = db_client.get_collection(name=collection_name)
    except ValueError as e:
        return f"Error accessing collection: {str(e)}. Make sure the collection name is correct."

    hits = target.query(
        query_embeddings=[embedder.embed_query(query)],
        n_results=2,
        include=["documents"],
    )
    # One inner list of documents per query embedding: documents are joined
    # by a newline, groups by a blank line.
    grouped_text = ["\n".join(docs) for docs in hits["documents"]]
    return "\n\n".join(grouped_text)
# Sample account records returned after a successful login (demo data only).
_DEMO_ACCOUNT_INFO = """
"NeoBank": {
"user_id": "NB782940",
"user_name": "john_doe123",
"full_name": "John Doe",
"email": "john.doe@example.com",
"balance": 2875.43,
"transactions": [
{"date": "2024-06-20", "description": "Coffee Shop", "amount": -4.50},
{"date": "2024-06-19", "description": "Grocery Store", "amount": -85.22},
{"date": "2024-06-18", "description": "Salary Deposit", "amount": 2500.00}
]
},
"CryptoInvest": {
"user_id": "CI549217",
"user_name": "crypto_enthusiast",
"full_name": "Alice Johnson",
"email": "alice.johnson@example.com",
"portfolio": {
"BTC": {"amount": 0.025, "value": 7500.00},
"ETH": {"amount": 1.2, "value": 2100.00},
"SOL": {"amount": 5.8, "value": 450.50}
},
"transactions": [
{"date": "2024-06-22", "description": "Bought ETH", "amount": -500.00},
{"date": "2024-06-20", "description": "Sold BTC", "amount": 1200.00}
]
},
"RoboAdvisor": {
"user_id": "RA385712",
"user_name": "jane_smith",
"full_name": "Jane Smith",
"email": "jane.smith@example.com",
"risk_tolerance": "moderate",
"portfolio_value": 15800.75,
"allocations": {
"stocks": 0.60,
"bonds": 0.30,
"real_estate": 0.10
},
"recent_activity": [
{"date": "2024-06-21", "description": "Dividends received", "amount": 32.50},
{"date": "2024-06-15", "description": "Portfolio rebalanced" }
]
},
"PeerLend": {
"user_id": "PL916350",
"user_name": "bob_williams",
"full_name": "Bob Williams",
"email": "bob.williams@example.com",
"account_type": "borrower",
"loan_amount": 5000.00,
"interest_rate": 7.8,
"monthly_payment": 150.30,
"payment_history": [
{"date": "2024-06-22", "status": "paid"},
{"date": "2024-05-22", "status": "paid"},
{"date": "2024-04-22", "status": "paid"}
]
},
"InsureTech": {
"user_id": "IT264805",
"user_name": "eva_brown4",
"full_name": "Eva Brown",
"email": "eva.brown@example.com",
"policy_type": "auto",
"coverage_details": {
"liability": "50/100/50",
"collision": "500 deductible",
"comprehensive": "100 deductible"
},
"premium": 85.50,
"next_payment": "2024-07-10",
"claims": []
}
"""


# New function to handle login
def handle_login(username, password):
    """Check the supplied credentials and return the demo account data.

    The expected credentials default to ``admin`` / ``password`` and can be
    overridden with the ``ADMIN_USERNAME`` and ``ADMIN_PASSWORD`` environment
    variables, so a deployment does not have to run with the built-in pair.

    This is a simple example. In a real application, you'd want to use
    secure authentication methods.

    Args:
        username: Submitted user name; ``None`` is treated as empty.
        password: Submitted password; ``None`` is treated as empty.

    Returns:
        The sample account records as text when both values match, otherwise
        the string ``"Invalid username or password"``.
    """
    import hmac

    expected_username = os.getenv("ADMIN_USERNAME", "admin")
    expected_password = os.getenv("ADMIN_PASSWORD", "password")

    def _same(supplied, expected):
        # compare_digest runs in constant time; encoding to bytes lets it
        # accept non-ASCII input, which it rejects for str arguments.
        return hmac.compare_digest(
            (supplied or "").encode("utf-8"),
            expected.encode("utf-8"),
        )

    # Evaluate both comparisons unconditionally so the response time does not
    # reveal which of the two fields was wrong.
    username_ok = _same(username, expected_username)
    password_ok = _same(password, expected_password)
    if not (username_ok and password_ok):
        return "Invalid username or password"
    return _DEMO_ACCOUNT_INFO
# Gradio interface
def gradio_interface():
    """Build the admin dashboard and launch it on 0.0.0.0:7860.

    The dashboard has three tabs:

    * "Query Classifier Agent" - stores the organization name and level
      rules in ``user_inputs`` and runs ``classify_query``.
    * "Organization Documentation Agent" - uploads PDFs through
      ``handle_file_upload`` and searches them with
      ``search_vector_database``.
    * "Account Information" - checks credentials with ``handle_login``.

    The classify, search and login handlers are registered under explicit
    API names equal to their function names, matching the usage examples
    shown in each tab.
    """

    def save_rules(organization, rules_l1, rules_l2, rules_l3):
        """Copy the submitted form values into the shared ``user_inputs`` dict."""
        # user_inputs is a plain dict and must be updated by key: setattr()
        # on a dict raises AttributeError, so nothing would be stored and
        # classify_query() would always report unfilled fields.
        user_inputs.update(
            organization=organization,
            rules_l1=rules_l1,
            rules_l2=rules_l2,
            rules_l3=rules_l3,
        )

    with gr.Blocks(theme='gl198976/The-Rounded') as interface:
        gr.Markdown("# Admin Dashboard🧖🏻♀️")

        with gr.Tab("Query Classifier Agent"):
            with gr.Row():
                with gr.Column():
                    organization_input = gr.Textbox(label="Organization Name")
                    rules_l1_input = gr.Textbox(label="Rules for Level 1 Query", lines=5)
                    rules_l2_input = gr.Textbox(label="Rules for Level 2 Query", lines=5)
                    rules_l3_input = gr.Textbox(label="Rules for Level 3 Query", lines=5)
                    submit_btn = gr.Button("Submit Rules")
                with gr.Column():
                    query_input = gr.Textbox(label="Customer Query")
                    classification_output = gr.Textbox(label="Classification Result")
                    classify_btn = gr.Button("Classify Query")
            # The example names the endpoint that is actually registered
            # below ("/classify_query").
            gr.Markdown("""
            ### API Endpoint Details
            - **URL:** `http://0.0.0.0:7860/classify`
            - **Method:** POST
            - **Request Body:** JSON with a single key `query`
            - **Example Usage:**
            ```python
            from gradio_client import Client
            client = Client("http://0.0.0.0:7860/")
            result = client.predict(
            "Hello!!", # str in 'Customer Query' Textbox component
            api_name="/classify_query"
            )
            print(result)
            ```
            """)
            submit_btn.click(
                save_rules,
                inputs=[organization_input, rules_l1_input, rules_l2_input, rules_l3_input],
            )
            classify_btn.click(
                classify_query,
                inputs=[query_input],
                outputs=[classification_output],
                api_name="classify_query",
            )

        with gr.Tab("Organization Documentation Agent"):
            gr.Markdown("""
            ### Warning
            If you encounter an error when uploading files, try changing the collection name and upload again.
            Each collection name must be unique.
            """)
            with gr.Row():
                with gr.Column():
                    collection_name_input = gr.Textbox(label="Collection Name", placeholder="Enter a unique name for this collection")
                    file_upload = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                    upload_btn = gr.Button("Upload and Process Files")
                    upload_status = gr.Textbox(label="Upload Status", interactive=False)
                with gr.Column():
                    search_query_input = gr.Textbox(label="Search Query")
                    search_output = gr.Textbox(label="Search Results", lines=10)
                    search_btn = gr.Button("Search")
            gr.Markdown("""
            ### API Endpoint Details
            - **URL:** `http://0.0.0.0:7860/search_vector_database`
            - **Method:** POST
            - **Example Usage:**
            ```python
            from gradio_client import Client
            client = Client("http://0.0.0.0:7860/")
            result = client.predict(
            "search query", # str in 'Search Query' Textbox component
            "name of collection given in ui", # str in 'Collection Name' Textbox component
            api_name="/search_vector_database"
            )
            print(result)
            ```
            """)
            upload_btn.click(
                handle_file_upload,
                inputs=[file_upload, collection_name_input],
                outputs=[upload_status],
            )
            search_btn.click(
                search_vector_database,
                inputs=[search_query_input, collection_name_input],
                outputs=[search_output],
                api_name="search_vector_database",
            )

        with gr.Tab("Account Information"):
            with gr.Row():
                with gr.Column():
                    username_input = gr.Textbox(label="Username")
                    password_input = gr.Textbox(label="Password", type="password")
                    login_btn = gr.Button("Login")
                with gr.Column():
                    account_info_output = gr.Textbox(label="Account Info", lines=20)
            gr.Markdown("""
            ### API Endpoint Details
            - **URL:** `http://0.0.0.0:7860/handle_login`
            - **Method:** POST
            - **Example Usage:**
            ```python
            from gradio_client import Client
            client = Client("http://0.0.0.0:7860/")
            result = client.predict(
            "admin", # str in 'Username' Textbox component
            "password", # str in 'Password' Textbox component
            api_name="/handle_login"
            )
            print(result)
            ```
            """)
            login_btn.click(
                handle_login,
                inputs=[username_input, password_input],
                outputs=[account_info_output],
                api_name="handle_login",
            )

    interface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    gradio_interface()