import os
import base64
import gradio as gr
import json
from datetime import datetime
from symbol_detection import run_detection_with_optimal_threshold
from line_detection_ai import DiagramDetectionPipeline, LineDetector, LineConfig, ImageConfig, DebugHandler, \
PointConfig, JunctionConfig, PointDetector, JunctionDetector, SymbolConfig, SymbolDetector, TagConfig, TagDetector
from data_aggregation_ai import DataAggregator
from chatbot_agent import get_assistant_response
from storage import StorageFactory, LocalStorage
import traceback
from text_detection_combined import process_drawing
from pathlib import Path
from pdf_processor import DocumentProcessor
import networkx as nx
import logging
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import torch
from graph_visualization import create_graph_visualization
import shutil
from detection_schema import BBox # Add this import
import cv2
import numpy as np
import time
from huggingface_hub import HfApi, login
from download_models import download_from_azure
# Load environment variables from .env before anything reads them.
load_dotenv()

# Configure root logging once, at import time.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Silence chatty third-party and sibling modules so only warnings surface.
for _noisy_logger in ('PIL', 'matplotlib', 'gradio', 'networkx',
                      'line_detection_ai', 'symbol_detection'):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
def log_process_step(message, level=logging.INFO):
    """Log a processing step, suppressing routine chatter.

    Warnings and errors are always emitted; lower-level messages are
    only logged when they announce a completion/generation milestone.
    """
    if level >= logging.WARNING:
        logger.log(level, message)
        return
    lowered = message.lower()
    if "completed" in lowered or "generated" in lowered:
        logger.info(message)
def get_timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    now = datetime.now()
    return now.strftime('%Y-%m-%d %H:%M:%S')
def format_message(role, content):
    """Build a chatbot-history entry in Gradio 'messages' format."""
    return dict(role=role, content=content)
# Preload avatar images as base64 strings for inline embedding in chat HTML.
localStorage = LocalStorage()


def _b64_asset(asset_path):
    """Read an asset via local storage and return it base64-encoded (str)."""
    return base64.b64encode(localStorage.load_file(asset_path)).decode()


agent_avatar = _b64_asset("assets/AiAgent.png")
llm_avatar = _b64_asset("assets/llm.png")
user_avatar = _b64_asset("assets/user.png")
# Chat message formatting with avatars and enhanced HTML for readability
def chat_message(role, message, avatar, timestamp):
    """Render a single chat turn as an HTML bubble.

    Fixes the previous version, whose string literals were split across
    physical lines (a syntax error) and which never used *avatar* or the
    CSS classes declared in ``custom_css``.

    Args:
        role: "user" or "assistant"; added as an extra CSS class.
        message: Raw message text, possibly containing Markdown markers.
        avatar: Base64-encoded PNG for the speaker's avatar.
        timestamp: Pre-formatted timestamp string shown under the bubble.

    Returns:
        str: HTML fragment styled by the .chat-message rules in custom_css.
    """
    # Strip lightweight Markdown decoration; the bubble renders plain text.
    cleaned = (
        message.replace("**", "")
        .replace("###", "")
        .replace("##", "")
        .replace("#", "")
        .replace("```", "")
        .replace("`", "")
    )
    # Preserve line structure with explicit breaks.
    formatted_message = cleaned.replace("\n", "<br>")
    return f"""
<div class="chat-message {role}">
    <img class="avatar" src="data:image/png;base64,{avatar}" alt="{role}">
    <div>
        <div class="speech-bubble">{formatted_message}</div>
        <div class="timestamp">{timestamp}</div>
    </div>
</div>
"""
def resize_to_fit(image_path, max_width=800, max_height=600):
    """Load an image from disk and scale it to fit within the given bounds.

    Aspect ratio is preserved; the image is always rescaled so that both
    dimensions fit the editor window.

    Returns:
        tuple: (resized image, scale factor), or (None, 1.0) when the
        file cannot be read as an image.
    """
    img = cv2.imread(image_path)
    if img is None:
        # Unreadable path or non-image file: signal failure, identity scale.
        return None, 1.0
    height, width = img.shape[:2]
    # The tighter of the two constraints guarantees both dimensions fit.
    scale = min(max_width / width, max_height / height)
    target_size = (int(width * scale), int(height * scale))
    return cv2.resize(img, target_size), scale
# Main processing function for P&ID steps
def process_pnid(image_file, progress=gr.Progress()):
    """Process P&ID document with real-time progress updates.

    Generator used as a Gradio event handler: each ``yield`` pushes the
    current state of all nine UI outputs wired in create_ui() — progress
    text, six preview images, chat history, and the aggregated-JSON path.

    Args:
        image_file: Uploaded file object from gr.File (exposes ``.name``).
        progress: Gradio progress tracker (injected by Gradio).

    Yields:
        list: Nine-element list matching create_ui()'s outputs order.

    Raises:
        ValueError: on missing upload or any failed pipeline stage
        (re-raised after recording an error message in the chat history).
    """
    try:
        if image_file is None:
            raise ValueError("No file uploaded. Please upload a file first.")
        progress_text = []
        outputs = [None] * 9  # Changed from 8 to 9 to match UI outputs
        base_name = os.path.splitext(os.path.basename(image_file.name))[0] + "_page_1"
        # Initialize chat history with proper format
        chat_history = [{"role": "assistant", "content": "Welcome! Upload a P&ID to begin analysis."}]
        outputs[7] = chat_history  # Chat history moved to index 7

        def update_progress(step, message):
            # Append a timestamped line to the progress log, refresh the
            # textbox content, and advance Gradio's progress bar.
            progress_text.append(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {message}")
            outputs[0] = "\n".join(progress_text)  # Progress text
            progress(step)

        # Initialize storage and results directory
        storage = StorageFactory.get_storage()
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)
        # Clean results directory (top-level files only; subdirectories kept)
        logger.info("Cleaned results directory: results")
        for file in os.listdir(results_dir):
            file_path = os.path.join(results_dir, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                # Best-effort cleanup: a locked/busy file should not abort the run.
                logger.error(f"Error deleting file {file_path}: {str(e)}")

        # Step 1: File Upload (10%)
        logger.info(f"Processing file: {os.path.basename(image_file.name)}")
        update_progress(0.1, "Step 1/7: File uploaded successfully")
        yield outputs

        # Step 2: Document Processing - Get high quality PNG
        update_progress(0.2, "Step 2/7: Processing document...")
        doc_processor = DocumentProcessor(storage)
        processed_pages = doc_processor.process_document(
            file_path=image_file,
            output_dir=results_dir
        )
        if not processed_pages:
            raise ValueError("No pages processed from document")
        # Use high quality PNG for everything (only the first page is used)
        high_quality_png = processed_pages[0]
        outputs[1] = high_quality_png  # P&ID Tab shows original high quality
        update_progress(0.25, "Document loaded and displayed")
        yield outputs

        # Step 3: Symbol Detection using high quality PNG
        detection_image_path, detection_json_path, _, diagram_bbox = run_detection_with_optimal_threshold(
            high_quality_png,  # Use high quality PNG
            results_dir=results_dir,
            file_name=os.path.basename(high_quality_png),
            storage=storage,
            resize_image=False  # Don't resize
        )
        outputs[2] = detection_image_path  # Symbols Tab
        symbol_json_path = detection_json_path

        # Step 4: Text Detection using high quality PNG
        text_results, text_summary = process_drawing(
            high_quality_png,  # Use high quality PNG
            results_dir,
            storage
        )
        text_json_path = text_results['json_path']
        outputs[3] = text_results['image_path']  # Tags Tab

        # Step 5: Line Detection (80%)
        update_progress(0.80, "Step 5/7: Line Detection")
        yield outputs
        try:
            # Initialize components
            debug_handler = DebugHandler(enabled=True, storage=storage)
            # Configure detectors
            line_config = LineConfig()
            point_config = PointConfig()
            junction_config = JunctionConfig()
            symbol_config = SymbolConfig(
                model_path="models/Intui_SDM_41.pt",
                confidence_threshold=0.5,
                nms_threshold=0.3
            )
            tag_config = TagConfig(
                model_path="models/tag_detection.json",
                confidence_threshold=0.5
            )
            # Create all required detectors
            symbol_detector = SymbolDetector(
                config=symbol_config,
                debug_handler=debug_handler
            )
            tag_detector = TagDetector(
                config=tag_config,
                debug_handler=debug_handler
            )
            # NOTE(review): hard-coded CUDA device — this raises on CPU-only
            # hosts; consider torch.device("cuda" if torch.cuda.is_available()
            # else "cpu"). Confirm deployment targets before changing.
            line_detector = LineDetector(
                config=line_config,
                model_path="models/deeplsd_md.tar",
                model_config={"detect_lines": True},
                device=torch.device("cuda"),
                debug_handler=debug_handler
            )
            point_detector = PointDetector(
                config=point_config,
                debug_handler=debug_handler
            )
            junction_detector = JunctionDetector(
                config=junction_config,
                debug_handler=debug_handler
            )
            # Create pipeline with all detectors
            pipeline = DiagramDetectionPipeline(
                tag_detector=tag_detector,
                symbol_detector=symbol_detector,
                line_detector=line_detector,
                point_detector=point_detector,
                junction_detector=junction_detector,
                storage=storage,
                debug_handler=debug_handler
            )
            # Run pipeline with original high-res image
            line_results = pipeline.run(
                image_path=high_quality_png,
                output_dir=results_dir,
                config=ImageConfig()
            )
            line_json_path = line_results.json_path
            outputs[4] = line_results.image_path
            # Verify line detection output
            if not os.path.exists(line_json_path):
                raise ValueError(f"Line detection JSON not found: {line_json_path}")
            # Verify line detection JSON content
            with open(line_json_path, 'r') as f:
                line_data = json.load(f)
                if 'lines' not in line_data:
                    raise ValueError(f"Invalid line detection data format in {line_json_path}")
            logger.info(f"Line detection completed successfully with {len(line_data['lines'])} lines")

            # Verify all required JSONs exist before aggregation
            required_jsons = {
                'symbols': symbol_json_path,
                'texts': text_json_path,
                'lines': line_json_path
            }
            for name, path in required_jsons.items():
                if not os.path.exists(path):
                    raise ValueError(f"{name} JSON not found: {path}")
                # Verify JSON can be loaded
                with open(path, 'r') as f:
                    data = json.load(f)
                    logger.info(f"Loaded {name} JSON with {len(data.get('detections', data.get('lines', [])))} items")

            # Data Aggregation
            aggregator = DataAggregator(storage=storage)
            aggregated_result = aggregator.process_data(
                image_path=high_quality_png,
                output_dir=results_dir,
                symbols_path=symbol_json_path,
                texts_path=text_json_path,
                lines_path=line_json_path
            )
            # Verify aggregation result before graph creation
            if not aggregated_result.get('success'):
                raise ValueError(f"Data aggregation failed: {aggregated_result.get('error')}")
            aggregated_json_path = aggregated_result['json_path']
            if not os.path.exists(aggregated_json_path):
                raise ValueError(f"Aggregated JSON not found: {aggregated_json_path}")
            # Verify aggregated JSON content
            with open(aggregated_json_path, 'r') as f:
                aggregated_data = json.load(f)
                required_keys = ['nodes', 'edges', 'symbols', 'texts', 'lines']
                missing_keys = [k for k in required_keys if k not in aggregated_data]
                if missing_keys:
                    raise ValueError(f"Aggregated JSON missing required keys: {missing_keys}")
            logger.info("Aggregation completed successfully with:")
            logger.info(f"- {len(aggregated_data['nodes'])} nodes")
            logger.info(f"- {len(aggregated_data['edges'])} edges")

            # After aggregation, create graph visualization
            update_progress(0.85, "Step 6/7: Creating Knowledge Graph")
            try:
                # Create graph visualization
                graph_results = create_graph_visualization(
                    json_path=aggregated_json_path,
                    output_dir=results_dir,
                    base_name=base_name,
                    save_plot=True
                )
                if not graph_results.get('success'):
                    logger.error(f"Error in graph generation: {graph_results.get('error')}")
                    raise Exception(graph_results.get('error'))
                # Path convention presumably produced by create_graph_visualization
                # — TODO confirm it matches that module's output naming.
                graph_path = f"results/{base_name}_graph_visualization.png"
                if not os.path.exists(graph_path):
                    raise Exception("Graph visualization file not created")
                update_progress(0.90, "Step 6/7: Knowledge Graph Created")
            except Exception as e:
                logger.error(f"Error creating graph visualization: {str(e)}")
                raise

            # Fix output assignments
            # NOTE(review): assigns the raw list here while update_progress
            # joins with newlines; the update_progress(0.95, ...) call below
            # overwrites it with joined text, so this is harmless — confirm.
            outputs[0] = progress_text  # Progress text
            outputs[1] = high_quality_png  # P&ID
            outputs[2] = detection_image_path  # Symbols
            outputs[3] = text_results['image_path']  # Tags
            outputs[4] = line_results.image_path  # Lines
            outputs[5] = f"results/{base_name}_aggregated.png"  # Aggregated
            outputs[6] = graph_path  # Graph visualization
            outputs[7] = chat_history  # Chat
            outputs[8] = aggregated_json_path  # JSON state

            # Update progress with all steps
            update_progress(0.95, "Step 7/7: Finalizing Results")
            chat_history = [{"role": "assistant", "content": "Processing complete! I can help answer questions about the P&ID contents."}]
            outputs[7] = chat_history
            update_progress(1.0, "✅ Processing Complete")
            yield outputs
        except Exception as e:
            # Update chat with error message before re-raising to the outer handler
            chat_history = [{"role": "assistant", "content": f"Error during processing: {str(e)}"}]
            outputs[7] = chat_history
            raise
    except Exception as e:
        logger.error(f"Error in process_pnid: {str(e)}")
        logger.error(f"Stack trace:\n{traceback.format_exc()}")
        # Update chat with error message
        # NOTE(review): if the failure was the initial image_file-is-None
        # check, ``outputs`` is not yet bound here and this line raises
        # UnboundLocalError, masking the original error — confirm and fix.
        chat_history = [{"role": "assistant", "content": f"Error: {str(e)}"}]
        outputs[7] = chat_history
        raise
# Separate function for Chat interaction
def handle_user_message(user_input, chat_history, json_path_state):
    """Handle a chat turn: append the user message and the assistant reply.

    Bug fix: the previous version concatenated the list-typed history with
    the *string* returned by ``chat_message`` (``list + str`` raises
    TypeError). History entries are now appended as the
    ``{"role", "content"}`` dicts that ``gr.Chatbot(type="messages")``
    expects.

    Args:
        user_input: Raw text from the chat box (may be empty/whitespace).
        chat_history: Current history as a list of role/content dicts.
        json_path_state: Path to the aggregated-results JSON, or None if
            no document has been processed yet.

    Returns:
        list: The updated chat history.
    """
    try:
        # Ignore empty/whitespace-only submissions.
        if not user_input or not user_input.strip():
            return chat_history

        new_history = chat_history + [{"role": "user", "content": user_input}]

        # A processed P&ID is required before the assistant can answer.
        if not json_path_state or not os.path.exists(json_path_state):
            return new_history + [{
                "role": "assistant",
                "content": "Please upload and process a P&ID document first.",
            }]

        try:
            # Log for debugging
            logger.info(f"Sending question to assistant: {user_input}")
            logger.info(f"Using JSON path: {json_path_state}")
            # Generate response
            response = get_assistant_response(user_input, json_path_state)

            # Handle the response
            if isinstance(response, (str, dict)):
                response_text = str(response)
            else:
                try:
                    # Generators: take the first yielded chunk.
                    response_text = next(response) if hasattr(response, '__next__') else str(response)
                except StopIteration:
                    response_text = "I apologize, but I couldn't generate a response."
                except Exception as e:
                    logger.error(f"Error processing response: {str(e)}")
                    response_text = "I apologize, but I encountered an error processing your request."

            logger.info(f"Generated response: {response_text}")
            if not response_text.strip():
                response_text = "I apologize, but I couldn't generate a response. Please try asking your question differently."

            # Add response to chat history
            new_history = new_history + [{"role": "assistant", "content": response_text}]
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            logger.error(traceback.format_exc())
            new_history = new_history + [{
                "role": "assistant",
                "content": "I apologize, but I encountered an error processing your request. Please try again.",
            }]
        return new_history
    except Exception as e:
        logger.error(f"Chat error: {str(e)}")
        logger.error(traceback.format_exc())
        return chat_history + [{
            "role": "assistant",
            "content": "I apologize, but something went wrong. Please try again.",
        }]
# App-wide dark-theme stylesheet: status console, preview tabs, chat
# bubbles (.chat-message/.speech-bubble/.timestamp used by chat_message),
# and the logo bar.
# NOTE(review): create_ui() builds its own css string and passes that to
# gr.Blocks, so this module-level stylesheet appears unused here — confirm
# no other module imports it before removing.
custom_css = """
.full-height-row {
height: calc(100vh - 150px); /* Adjusted height */
margin: 0;
padding: 10px;
}
.upload-box {
background: #2a2a2a;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
border: 1px solid #3a3a3a;
}
.status-box-container {
background: #2a2a2a;
border-radius: 8px;
padding: 15px;
height: calc(100vh - 350px); /* Reduced height */
border: 1px solid #3a3a3a;
margin-bottom: 15px;
}
.status-box {
font-family: 'Courier New', monospace;
font-size: 12px;
line-height: 1.4;
background-color: #1a1a1a;
color: #00ff00;
padding: 10px;
border-radius: 5px;
height: calc(100% - 40px); /* Adjust for header */
overflow-y: auto;
white-space: pre-wrap;
word-wrap: break-word;
border: none;
}
.preview-tabs {
height: calc(100vh - 100px); /* Increased container height from 200px */
background: #2a2a2a;
border-radius: 8px;
padding: 15px;
border: 1px solid #3a3a3a;
margin-bottom: 15px;
}
.chat-container {
height: 100%; /* Take full height */
display: flex;
flex-direction: column;
background: #2a2a2a;
border-radius: 8px;
padding: 15px;
border: 1px solid #3a3a3a;
}
.chatbox {
flex: 1; /* Take remaining space */
overflow-y: auto;
background: #1a1a1a;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
color: #ffffff;
min-height: 200px; /* Ensure minimum height */
}
.chat-input-group {
height: auto; /* Allow natural height */
min-height: 120px; /* Minimum height for input area */
background: #1a1a1a;
border-radius: 8px;
padding: 15px;
margin-top: auto; /* Push to bottom */
}
.chat-input {
background: #2a2a2a;
color: #ffffff;
border: 1px solid #3a3a3a;
border-radius: 5px;
padding: 12px;
min-height: 80px;
width: 100%;
margin-bottom: 10px;
}
.send-button {
width: 100%;
background: #4a4a4a;
color: #ffffff;
border-radius: 5px;
border: none;
padding: 12px;
cursor: pointer;
transition: background-color 0.3s;
}
.result-image {
border-radius: 8px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
margin: 10px 0;
background: #ffffff;
}
.chat-message {
display: flex;
margin-bottom: 1rem;
align-items: flex-start;
}
.chat-message .avatar {
width: 40px;
height: 40px;
margin-right: 10px;
border-radius: 50%;
}
.chat-message .speech-bubble {
background: #2a2a2a;
padding: 10px 15px;
border-radius: 10px;
max-width: 80%;
margin-bottom: 5px;
}
.chat-message .timestamp {
font-size: 0.8em;
color: #666;
}
.logo-row {
width: 100%;
background-color: #1a1a1a;
padding: 10px 0;
margin: 0;
border-bottom: 1px solid #3a3a3a;
}
"""
def create_ui():
    """Construct the Gradio UI and wire all event handlers.

    Layout: logo row, then three columns — upload/progress (left),
    tabbed image previews (center), chat panel (right).

    Returns:
        gr.Blocks: the assembled (un-launched) Gradio app.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logo_path = os.path.join(current_dir, "assets", "intuigence.png")
    # Stylesheet actually passed to gr.Blocks (the module-level custom_css
    # above is not used here).
    css = """
/* Theme colors */
:root {
--orange-primary: #ff6b2b;
--orange-hover: #ff8651;
--orange-light: rgba(255, 107, 43, 0.1);
}
/* Logo styling */
.logo-container {
padding: 10px 20px;
margin-bottom: 10px;
text-align: left;
width: 100%;
background: #1a1a1a; /* Match app background */
border-bottom: 1px solid #3a3a3a;
}
.logo-container img {
max-height: 40px;
width: auto;
display: inline-block !important;
}
/* Hide download and fullscreen buttons for logo */
.logo-container .download-button,
.logo-container .fullscreen-button {
display: none !important;
}
/* Adjust main content padding */
.main-content {
padding-top: 10px;
}
/* Custom orange theme */
.primary-button {
background: var(--orange-primary) !important;
color: white !important;
border: none !important;
}
.primary-button:hover {
background: var(--orange-hover) !important;
}
/* Tab styling */
.tabs > .tab-nav > button.selected {
border-color: var(--orange-primary) !important;
color: var(--orange-primary) !important;
}
.tabs > .tab-nav > button:hover {
border-color: var(--orange-hover) !important;
color: var(--orange-hover) !important;
}
/* File upload button */
.file-upload {
background: var(--orange-primary) !important;
}
/* Progress bar */
.progress-bar > div {
background: var(--orange-primary) !important;
}
/* Tags and labels */
.label-wrap {
background: var(--orange-primary) !important;
}
/* Selected/active states */
.selected, .active, .focused {
border-color: var(--orange-primary) !important;
color: var(--orange-primary) !important;
}
/* Links and interactive elements */
a, .link, .interactive {
color: var(--orange-primary) !important;
}
a:hover, .link:hover, .interactive:hover {
color: var(--orange-hover) !important;
}
/* Input focus states */
input:focus, textarea:focus {
border-color: var(--orange-primary) !important;
box-shadow: 0 0 0 1px var(--orange-light) !important;
}
/* Checkbox and radio */
input[type="checkbox"]:checked, input[type="radio"]:checked {
background-color: var(--orange-primary) !important;
border-color: var(--orange-primary) !important;
}
"""
    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        # Logo row (before main content)
        with gr.Row(elem_classes="logo-container"):
            gr.Image(
                value=logo_path,
                show_label=False,
                container=False,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                height=40
            )
        # State for storing file path
        # NOTE(review): file_path is never read or written anywhere in this
        # builder — confirm before removing.
        file_path = gr.State()
        # Holds the aggregated-JSON path produced by process_pnid (output
        # index 8) and consumed by the chat handlers.
        json_path = gr.State()
        # Main content row
        with gr.Row(elem_classes="main-content"):
            # Left column - File Upload & Processing
            with gr.Column(scale=3, elem_classes="column-panel"):
                file_output = gr.File(label="Upload P&ID Document")
                process_button = gr.Button(
                    "Process Document",
                    elem_classes="primary-button"  # Add custom class
                )
                progress_output = gr.Textbox(
                    label="Progress",
                    value="Waiting for document...",
                    interactive=False
                )
            # Center column - Preview Panel: one tab per pipeline stage
            with gr.Column(scale=5, elem_classes="column-panel preview-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("P&ID"):
                        input_image = gr.Image(type="filepath", label="Original")
                    with gr.TabItem("Symbols"):
                        symbol_image = gr.Image(type="filepath", label="Detected Symbols")
                    with gr.TabItem("Tags"):
                        text_image = gr.Image(type="filepath", label="Detected Tags")
                    with gr.TabItem("Lines"):
                        line_image = gr.Image(type="filepath", label="Detected Lines")
                    with gr.TabItem("Aggregated"):
                        aggregated_image = gr.Image(type="filepath", label="Aggregated Results")
                    with gr.TabItem("Knowledge Graph"):
                        graph_image = gr.Image(type="filepath", label="Knowledge Graph")
            # Right column - Chat Interface
            with gr.Column(scale=4, elem_classes="column-panel chat-panel", elem_id="chat-panel"):
                chat_history = gr.Chatbot(
                    [],
                    elem_classes="chat-history",
                    height=400,
                    show_label=False,
                    type="messages",
                    elem_id="chat-history"
                )
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask me about the P&ID...",
                        show_label=False,
                        container=False
                    )
                    chat_button = gr.Button(
                        "Send",
                        elem_classes="primary-button"  # Add custom class
                    )

        def handle_chat(user_message, chat_history, json_path):
            # Append the user turn, fetch the assistant reply, and return
            # ("", history) so the input box is cleared on submit.
            if not user_message:
                return "", chat_history
            # Add user message
            chat_history = chat_history + [{"role": "user", "content": user_message}]
            try:
                # Get assistant response
                response = get_assistant_response(user_message, json_path)
                # Add assistant response
                chat_history = chat_history + [{"role": "assistant", "content": response}]
            except Exception as e:
                logger.error(f"Error in chat response: {str(e)}")
                chat_history = chat_history + [
                    {"role": "assistant", "content": "I apologize, but I encountered an error processing your request."}
                ]
            return "", chat_history

        # Connect UI elements: Enter key and Send button share one handler
        chat_input.submit(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])
        chat_button.click(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])
        # process_pnid is a generator; each yield updates these nine outputs
        process_button.click(
            process_pnid,
            inputs=[file_output],
            outputs=[
                progress_output,  # Progress text (0)
                input_image,  # P&ID (1)
                symbol_image,  # Symbols (2)
                text_image,  # Tags (3)
                line_image,  # Lines (4)
                aggregated_image,  # Aggregated (5)
                graph_image,  # Graph (6)
                chat_history,  # Chat (7)
                json_path  # State (8)
            ],
            show_progress="hidden"  # Hide progress in tabs
        )
    return demo
def main():
    """Entry point: ensure model weights are present, then launch the app."""
    # Model files expected on disk; any missing one triggers a bulk download.
    # NOTE(review): these paths differ from the model paths used inside the
    # pipeline (e.g. "models/deeplsd_md.tar" vs "models/deeplsd/…") — confirm
    # they match the actual layout produced by download_from_azure().
    required_models = [
        'models/yolo/yolov8n.pt',
        'models/deeplsd/deeplsd_md.tar',
        'models/doctr/craft_mlt_25k.pth',
        'models/doctr/english_g2.pth',
        'models/yolo/intui_LDM_01.pt'
    ]
    missing = [model for model in required_models if not os.path.exists(model)]
    if missing:
        download_from_azure()

    app = create_ui()
    # Local development settings (no HF Spaces conditional):
    # bind all interfaces on port 7861 (moved off the default 7860).
    app.launch(server_name="0.0.0.0", server_port=7861, share=True)


if __name__ == "__main__":
    main()