# Spaces:
# Runtime error
# Runtime error
# chatbot_agent.py
import json
import logging
import os
import re
import traceback

from dotenv import load_dotenv
from openai import OpenAI

# Pull OPENAI_API_KEY (and any other settings) from a local .env file.
load_dotenv()

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Initialize OpenAI client with error handling
def get_openai_client():
    """Build an OpenAI client from the OPENAI_API_KEY environment variable.

    Raises:
        ValueError: If no API key is present in the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if key:
        return OpenAI(api_key=key)
    raise ValueError("OpenAI API key not found in environment variables")
def format_message(role, content):
    """Return a chat-history entry pairing *role* with *content*."""
    return dict(role=role, content=content)
def initialize_graph_prompt(graph_data):
    """Build the system prompt describing the P&ID knowledge graph.

    Args:
        graph_data: Dict with optional 'summary' (element counts) and
            'detailed_results' (per-element records; a 'symbols' list
            is described node by node) keys.

    Returns:
        str: A system prompt for the chat model, or a generic fallback
        sentence if prompt construction fails for any reason.
    """
    try:
        summary = graph_data.get('summary', {})

        # Only mention counts that are actually present in the summary.
        count_labels = [
            ('symbol_count', 'Symbols'),
            ('text_count', 'Texts'),
            ('line_count', 'Lines'),
            ('edge_count', 'Edges'),
        ]
        summary_parts = [f"{label}: {summary[key]}"
                         for key, label in count_labels if key in summary]
        # Fix: avoid emitting a bare "." when no counts are available.
        if summary_parts:
            summary_info = ", ".join(summary_parts) + "."
        else:
            summary_info = "No summary counts available."

        # Describe each symbol using whichever of the known fields it carries.
        node_details = ""
        detailed_results = graph_data.get('detailed_results', {})
        if 'symbols' in detailed_results:
            field_labels = [
                ('symbol_id', 'ID'),
                ('class_id', 'Class'),
                ('category', 'Category'),
                ('type', 'Type'),
                ('label', 'Label'),
            ]
            node_details = "Nodes (symbols) in the graph include:\n"
            for symbol in detailed_results['symbols']:
                details = [f"{label}: {symbol[key]}"
                           for key, label in field_labels if key in symbol]
                if details:  # Only add a line if the symbol had any known field
                    node_details += ", ".join(details) + "\n"

        initial_prompt = (
            "You have access to a knowledge graph generated from a P&ID diagram. "
            f"The summary information includes:\n{summary_info}\n\n"
            f"{node_details}\n"
            "Answer questions about the P&ID elements using this information."
        )
        return initial_prompt
    except Exception as e:
        logger.error(f"Error creating initial prompt: {str(e)}")
        # Degrade gracefully so the chat can continue without graph details.
        return ("I have access to a P&ID diagram knowledge graph. "
                "I can help answer questions about the diagram elements.")
def get_assistant_response(user_message, json_path):
    """Answer a question about a P&ID, via rules or the OpenAI API.

    Count questions (valves, pumps, equipment) are answered directly from
    the detection data; anything else is forwarded to the chat model with
    a system prompt built from the graph.

    Args:
        user_message: The user's question (free text).
        json_path: Path to the aggregated-detections JSON file.

    Returns:
        str: The assistant's answer, or an apology message on error.
    """
    try:
        # Load the aggregated data.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        question = user_message.lower()

        # Rule-based responses for specific questions. These need no API
        # key, so the OpenAI client is only created when actually required
        # (fix: previously a missing key broke even these offline answers).
        # Note: "valve" / "pump" substring checks also match their plurals.
        if "valve" in question:
            valve_count = sum(1 for symbol in data.get('symbols', [])
                              if 'class' in symbol and 'valve' in symbol['class'].lower())
            return f"I found {valve_count} valves in this P&ID."
        elif "pump" in question:
            pump_count = sum(1 for symbol in data.get('symbols', [])
                             if 'class' in symbol and 'pump' in symbol['class'].lower())
            return f"I found {pump_count} pumps in this P&ID."
        elif "equipment" in question or "components" in question:
            # Tally symbols by their 'class' field.
            equipment_types = {}
            for symbol in data.get('symbols', []):
                if 'class' in symbol:
                    eq_type = symbol['class']
                    equipment_types[eq_type] = equipment_types.get(eq_type, 0) + 1
            response = "Here's a summary of the equipment I found:\n"
            for eq_type, count in equipment_types.items():
                response += f"- {eq_type}: {count}\n"
            return response
        else:
            # For other questions, use OpenAI with a graph-aware system prompt.
            client = get_openai_client()
            graph_data = {
                "summary": {
                    "symbol_count": len(data.get('symbols', [])),
                    "text_count": len(data.get('texts', [])),
                    "line_count": len(data.get('lines', [])),
                    "edge_count": len(data.get('edges', [])),
                },
                "detailed_results": data
            }
            initial_prompt = initialize_graph_prompt(graph_data)
            conversation = [
                {"role": "system", "content": initial_prompt},
                {"role": "user", "content": user_message}
            ]
            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=conversation
            )
            return response.choices[0].message.content
    except Exception as e:
        logger.error(f"Error in get_assistant_response: {str(e)}")
        logger.error(traceback.format_exc())
        return "I apologize, but I encountered an error analyzing the P&ID data. Please try asking a different question."
# Testing and Usage block
if __name__ == "__main__":
    # Load the knowledge graph data from JSON file.
    json_file_path = "results/0_aggregated_detections.json"
    try:
        with open(json_file_path, 'r', encoding='utf-8') as file:
            graph_data = json.load(file)
    except FileNotFoundError:
        print(f"Error: File not found at {json_file_path}")
        graph_data = None
    except json.JSONDecodeError:
        print("Error: Failed to decode JSON. Please check the file format.")
        graph_data = None

    # Initialize conversation history with assistant's welcome message.
    history = [format_message("assistant", "Hello! I am ready to answer your questions about the P&ID knowledge graph. The graph includes nodes (symbols), edges, linkers, and text tags, and I have detailed information available about each. Please ask any questions related to these elements and their connections.")]
    print("Assistant:", history[0]["content"])

    if graph_data:
        # Option 1: Test the graph prompt initialization.
        print("\n--- Test: Graph Prompt Initialization ---")
        initial_prompt = initialize_graph_prompt(graph_data)
        print(initial_prompt)

        # Option 2: Simulate a conversation with a test question.
        # Fix: get_assistant_response returns a single string; the old code
        # iterated over it, printing one character per "Assistant:" line.
        print("\n--- Test: Simulate Conversation ---")
        test_question = "Can you tell me about the connections between the nodes?"
        history.append(format_message("user", test_question))
        print(f"\nUser: {test_question}")
        response = get_assistant_response(test_question, json_file_path)
        print("Assistant:", response)
        history.append(format_message("assistant", response))

        # Option 3: Interactive loop; type "exit" or "quit" to stop.
        while True:
            user_question = input("\nYou: ")
            if user_question.lower() in ["exit", "quit"]:
                print("Exiting chat. Goodbye!")
                break
            history.append(format_message("user", user_question))
            response = get_assistant_response(user_question, json_file_path)
            print("Assistant:", response)
            history.append(format_message("assistant", response))
    else:
        print("Unable to load graph data. Please check the file path and format.")