File size: 5,535 Bytes
b5f78a6
 
 
 
 
 
 
 
bddd1de
 
 
 
 
 
 
c2c5906
bddd1de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import subprocess
import sys
import os

# Install the packages listed in the requirements.txt that sits next to this
# script, before the third-party imports below are executed.
# NOTE(review): pip is invoked unconditionally on every start-up (there is no
# "already installed" check); check_call raises CalledProcessError on failure.
requirements_file = os.path.join(
    os.path.dirname(__file__), "requirements.txt"
)
subprocess.check_call(
    [
        sys.executable, "-m", "pip", "install",
        "-r", requirements_file,
    ]
)

import gradio as gr
from utils.routing import route_agent
from agents.agent1_image_issue import handle_image_issue
from agents.agent2_tenancy_faq import handle_tenancy_query
from PIL import Image
import torch
import hashlib


# Helper to generate MD5 hash from image
def get_image_hash(image):
    """Return a hex MD5 fingerprint of an image, used to detect re-uploads.

    The digest covers the image mode and size as well as the raw pixel bytes.
    ``tobytes()`` alone is identical for, e.g., a 2x3 and a 3x2 image holding
    the same pixel data, which would make two different uploads look equal.
    MD5 serves only as a fast change detector here, not for security.

    Args:
        image: A PIL ``Image`` (any object exposing ``mode``, ``size`` and
            ``tobytes()``).

    Returns:
        A 32-character lowercase hexadecimal digest string.
    """
    digest = hashlib.md5()
    # Mode names never contain "(", and repr(size) starts with "(", so the
    # concatenation below is unambiguous.
    digest.update(image.mode.encode("utf-8"))
    digest.update(repr(tuple(image.size)).encode("ascii"))
    digest.update(image.tobytes())
    return digest.hexdigest()

# Main query handler
def handle_query(user_input, image=None, location='', history=None, context=None):
    """Route one user turn to the right agent and record it in the session.

    ``history`` and ``context`` are mutated in place and also returned, so the
    updated objects flow back into the caller's session state.

    Args:
        user_input: The user's question text.
        image: The currently attached image, or ``None``.
        location: Optional city/country typed by the user.
        history: List of ``(user_input, response)`` tuples. A fresh list is
            created when omitted.
        context: Session dict holding ``images``, ``image_hashes``,
            ``last_agent``, ``last_caption_data`` and ``location``. A fresh
            dict is created when omitted.

    Returns:
        A 4-tuple ``(response_text, history, context, status_text)``.

    Raises:
        RuntimeError: Re-raised unchanged unless its message reports a CUDA
            out-of-memory condition, in which case an error message is
            returned instead.
    """
    # Mutable defaults ([] / {}) would be created once and shared by every
    # call that omits them, leaking state between sessions.
    if history is None:
        history = []
    if context is None:
        context = {}

    try:
        # UI notices shown above the agent's answer, one per line.
        notices = []

        # Initialize context if missing
        context.setdefault("images", [])
        context.setdefault("image_hashes", [])
        context.setdefault("last_agent", None)
        context.setdefault("last_caption_data", None)

        # A new (different) image starts an image discussion: store it and
        # drop the location and any caption cached for the previous image.
        if image is not None:
            new_hash = get_image_hash(image)
            known_hashes = context["image_hashes"]
            if not known_hashes or new_hash != known_hashes[-1]:
                context["images"].append(image)
                known_hashes.append(new_hash)
                context["location"] = ""
                context["last_caption_data"] = None  # Reset cached caption
                notices.append("(New image attached. Starting image-related discussion.)")

        # NOTE(review): previously stored images stay in context here, so the
        # router below still sees an image context, and this notice repeats on
        # every text-only turn once any image was uploaded. Confirm intent.
        if image is None and context["images"]:
            notices.append("(Image removed. Continuing as text-only query.)")

        # The location input is only recorded while there is no image context.
        # NOTE(review): after any upload the stored location stays "" even if
        # the user typed one, so agent2 gets no location — confirm intended.
        if not context["images"] and location:
            context["location"] = location

        # Determine which agent should handle the query
        is_image_context = bool(context["images"])
        agent = route_agent(user_input, is_image_context)

        # Agent switch handling
        previous_agent = context["last_agent"]
        if previous_agent == 'agent1' and agent == 'agent2':
            notices.append("(Switching to tenancy discussion based on your query...)")
        elif previous_agent == 'agent2' and agent == 'agent1':
            notices.append("(Detected switch to image-based issue. Starting a new conversation...)")
            # Start over, but keep the most recent image (the one just
            # attached, or the last one stored). Rebuilding the list from
            # `image` alone emptied it whenever no image was attached on this
            # turn, leaving agent1 nothing to analyze.
            latest_images = context["images"][-1:]
            latest_hashes = context["image_hashes"][-1:]
            history.clear()
            context.clear()
            context["images"] = latest_images
            context["image_hashes"] = latest_hashes
            context["last_caption_data"] = None
            context["last_agent"] = None
            context["location"] = location or ""

        # Update current agent
        context["last_agent"] = agent

        # Run the correct agent
        if agent == 'agent1':
            if context["images"]:
                result = handle_image_issue(user_input, context["images"][-1], history, context)
            else:
                result = "No image found to analyze."
        else:
            result = handle_tenancy_query(user_input, {"location": context.get("location")}, history)

        # Prefix notices; joining a list avoids a stray leading newline when
        # only a switch notice is present.
        if notices:
            banner = "\n".join(notices)
            result = f"{banner}\n\n{result}"

        history.append((user_input, result))
        return result, history, context, "🟢 Chat Ongoing"

    except RuntimeError as e:
        # Turn a CUDA OOM into a user-facing message; re-raise anything else
        # with its original traceback intact.
        if "CUDA out of memory" in str(e):
            error_msg = "⚠️ CUDA Out of Memory! Please try again later or reduce the image size."
            return error_msg, history, context, "🔴 Error"
        raise

# Reset function
def reset_chat():
    """Return fresh values for every component cleared by "Start New Chat".

    The tuple is, in order: question text, location text, image (cleared to
    ``None``), conversation history, session context, session status text and
    response text. A new history list and context dict are built on every
    call, so separate resets never share them.
    """
    fresh_context = {"location": "", "images": [], "image_hashes": []}
    return "", "", None, [], fresh_context, "🟡 New Chat Started", ""

# Clear just the conversation history
def clear_chat_history():
    """Return values that empty the conversation history and response box.

    The tuple is ``(history, response_text, status_text)``. The session
    context (stored images, last agent, location) is not touched.
    """
    return [], "", "🧹 Chat history cleared"

# Build the Gradio interface
with gr.Blocks() as demo:
    # Per-session state: the (question, answer) history and the context dict
    # passed to handle_query.
    conversation_history = gr.State([])
    user_context = gr.State({"location": "", "images": [], "image_hashes": []})
    # NOTE(review): session_state is not wired to any event below; the
    # "Session Status" textbox is updated through event outputs instead.
    session_state = gr.State("🟡 New Chat Started")

    gr.Markdown("# 🏠 Multi-Agent Real Estate Chatbot")
    gr.Markdown("Ask about property issues (with images) or tenancy questions!")

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column():
            user_input = gr.Textbox(label="Enter your question:")
            location_input = gr.Textbox(label="Enter your city or country (optional):")
            # type="pil" makes Gradio pass a PIL image to the handler.
            image_input = gr.Image(type="pil", label="Upload an image (optional):")

            submit_btn = gr.Button("Submit")
            new_chat_btn = gr.Button("🔁 Start New Chat")
            clear_history_btn = gr.Button("🧹 Clear Chat History")

        # Right column: read-only outputs.
        with gr.Column():
            chatbot_output = gr.Textbox(label="Chatbot Response:", interactive=False, lines=8)
            session_indicator = gr.Textbox(label="Session Status", interactive=False)

    # Hook button logic. Each `outputs` list must match the order of the
    # handler's return tuple.
    submit_btn.click(
        handle_query,
        inputs=[user_input, image_input, location_input, conversation_history, user_context],
        outputs=[chatbot_output, conversation_history, user_context, session_indicator]
    )

    new_chat_btn.click(
        reset_chat,
        inputs=[],
        outputs=[user_input, location_input, image_input, conversation_history, user_context, session_indicator, chatbot_output]
    )

    clear_history_btn.click(
        clear_chat_history,
        inputs=[],
        outputs=[conversation_history, chatbot_output, session_indicator]
    )

# Launch app. share=True asks Gradio for a public share link.
# NOTE(review): this also runs when the module is imported; consider guarding
# it with `if __name__ == "__main__":` (Gradio's recommended pattern).
demo.launch(share=True)