# NOTE: the three lines below are Hugging Face web-page residue captured with
# the file (uploader / commit info), kept as comments so the module parses:
# Ryan2219's picture
# Update app.py
# 517bb60 verified
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 5 12:28:34 2025
@author: rmd2219
"""
import os, json, io, re, time
from PIL import Image
import pandas as pd
from google import genai
from google.genai import types
import gradio as gr
import hashlib
import numpy as np
from huggingface_hub import hf_hub_download, HfApi
from huggingface_hub.utils import EntryNotFoundError
# --- Runtime configuration (all overridable via environment variables) ------
USAGE_DATASET_REPO = os.environ.get("USAGE_DATASET_REPO", "NYSERDA-CRE-Working-Group/nyserda_demo_useage_store")
USAGE_FILENAME = os.environ.get("USAGE_FILENAME", "usage.csv")
MAX_RUNS_PER_USER = int(os.environ.get("MAX_RUNS_PER_USER", "10"))
# BUG FIX: the original did
#     os.environ["GEMINI_API_KEY"] = os.environ.get("GEMINI_API_KEY")
# which raises TypeError at import time when the variable is unset (environ
# values must be str), and is a no-op when it *is* set.  genai.Client() reads
# GEMINI_API_KEY from the environment directly, so just warn if it is missing.
if os.environ.get("GEMINI_API_KEY") is None:
    print("WARNING: GEMINI_API_KEY is not set; Gemini API calls will fail.")
HF_TOKEN = os.environ.get("HF_TOKEN")
# Authenticated client used to read/write the usage-tracking dataset repo.
api = HfApi(token=HF_TOKEN)
def user_id_from_profile(profile: gr.OAuthProfile | None) -> str | None:
    """Derive a normalized user id from a Gradio OAuth profile.

    Returns None when no profile is available or the name field is empty.
    """
    if profile is None:
        return None
    # profile.name is the field known to be present; preferred_username would
    # be more unique if it ever becomes accessible.
    raw_name = getattr(profile, "name", None)
    if not raw_name:
        return None
    return raw_name.strip().lower()
def _load_usage_df() -> pd.DataFrame:
    """Fetch the shared usage CSV from the HF dataset repo as a DataFrame.

    On the very first run the file does not exist yet; in that case an empty
    table with the expected columns is returned instead.
    """
    try:
        csv_path = hf_hub_download(
            repo_id=USAGE_DATASET_REPO,
            filename=USAGE_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except EntryNotFoundError:
        # First run: create empty table
        return pd.DataFrame(columns=["user_id", "runs", "first_seen", "last_seen"])
    return pd.read_csv(csv_path)
def _save_usage_df(df: pd.DataFrame, commit_message: str) -> None:
    """Serialize *df* to a temp CSV and push it back to the usage dataset repo."""
    local_copy = "/tmp/usage.csv"
    df.to_csv(local_copy, index=False)
    # One commit per quota update keeps an auditable history in the repo.
    api.upload_file(
        path_or_fileobj=local_copy,
        path_in_repo=USAGE_FILENAME,
        repo_id=USAGE_DATASET_REPO,
        repo_type="dataset",
        commit_message=commit_message,
    )
def check_and_increment_quota(user_id: str) -> tuple[bool, int]:
    """Record one run for *user_id* against the shared usage table.

    Returns (allowed, runs_remaining).  When the user is over quota nothing
    is written and (False, 0) is returned.
    NOTE(review): read-modify-write with no locking — concurrent runs could
    race; acceptable for a low-traffic demo.
    """
    now = int(time.time())
    usage = _load_usage_df()

    matching = usage.index[usage["user_id"] == user_id] if not usage.empty else []
    if len(matching) == 0:
        # Brand-new user: zero prior runs.
        if 0 >= MAX_RUNS_PER_USER:
            return False, 0
        fresh = pd.DataFrame([{
            "user_id": user_id,
            "runs": 1,
            "first_seen": now,
            "last_seen": now,
        }])
        usage = pd.concat([usage, fresh], ignore_index=True)
        _save_usage_df(usage, commit_message=f"usage: increment {user_id} to 1")
        return True, MAX_RUNS_PER_USER - 1

    row = matching[0]
    used = int(usage.loc[row, "runs"])
    if used >= MAX_RUNS_PER_USER:
        return False, 0
    used += 1
    usage.loc[row, "runs"] = used
    usage.loc[row, "last_seen"] = now
    _save_usage_df(usage, commit_message=f"usage: increment {user_id} to {used}")
    return True, MAX_RUNS_PER_USER - used
### Load in all preprocessed files
def _load_json(path):
    # Helper so each file handle is closed deterministically; the original
    # used `json.load(open(...))`, which leaks the handle until GC.
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)

tile_metadata = _load_json("tile_metadata.json")
global_context = _load_json("global_context.json")
text_list = _load_json("text_list.json")
page_metadata = _load_json("page_metadata.json")
def load_fullpage_images(folder="Images"):
    """Return the bytes of every page_<N>_fullpage.png in *folder*, ordered by N."""
    pattern = re.compile(r"page_(\d+)_fullpage\.png")
    # Pair each matching filename with its numeric page index.
    numbered = []
    for name in os.listdir(folder):
        hit = pattern.search(name)
        if hit:
            numbered.append((int(hit.group(1)), name))
    # Numeric (not lexicographic) page order.
    numbered.sort(key=lambda pair: pair[0])
    pages = []
    for _, name in numbered:
        with open(os.path.join(folder, name), "rb") as fh:
            pages.append(fh.read())
    return pages
def load_cropped_images(folder="Images"):
    """Return {page_number: png_bytes} for every page_<N>_drawing.png in *folder*."""
    pattern = re.compile(r"page_(\d+)_drawing\.png")
    numbered = []
    for name in os.listdir(folder):
        hit = pattern.search(name)
        if hit:
            numbered.append((int(hit.group(1)), name))
    numbered.sort(key=lambda pair: pair[0])
    # Dict keyed by page number: not every page necessarily has a crop.
    crops = {}
    for page_num, name in numbered:
        with open(os.path.join(folder, name), "rb") as fh:
            crops[page_num] = fh.read()
    return crops
def load_tile_images(page, folder='Tiles'):
    """Return the tile image bytes for *page*, ordered by tile index.

    Args:
        page: page number whose tiles to load.
        folder: directory containing page_<P>_tile_<T>.png files
            (default 'Tiles', matching the original hard-coded path).

    BUG FIX: the pattern was built as a plain f-string, so ``\\d`` and
    ``\\.`` were invalid escape sequences (SyntaxWarning on modern Python,
    and only working by accident).  It is now a raw f-string, compiled once.
    """
    pattern = re.compile(rf"page_{page}_tile_(\d+)\.png")
    numbered = []
    for name in os.listdir(folder):
        hit = pattern.search(name)
        if hit:
            numbered.append((int(hit.group(1)), name))
    # Numeric tile order.
    numbered.sort(key=lambda pair: pair[0])
    tiles = []
    for _, name in numbered:
        with open(os.path.join(folder, name), "rb") as fh:
            tiles.append(fh.read())
    return tiles
# Preload all page imagery once at startup.
image_bytes_list = load_fullpage_images()
cropped_bytes_list = load_cropped_images()
# tile_bytes: {page_number: [tile png bytes, ...]} for pages that have tiles.
# NOTE(review): the page range 0-43 is hard-coded — presumably the document's
# page count; confirm against the PDF.
tile_bytes = {}
for page in range(44):
    tile_list = load_tile_images(page)
    if tile_list:
        # BUG FIX: the original called load_tile_images(page) a second time
        # here, re-reading every tile from disk; reuse the loaded list.
        tile_bytes[page] = tile_list
class DrawingChatbot:
    """Dual-model question-answering pipeline over MEP drawings.

    Model A (the "Investigator") runs in forced-tool mode: it can only call
    tools, pulling page text/images and stitched tile crops until it calls
    ``investigation_complete``.  Model B (the "Analyst") receives the
    question, the investigation log, the raw tool outputs, and every image
    collected, and writes the user-facing answer.
    """

    def __init__(self, model_name, global_context, global_metadata, tile_metadata, text_data, image_data, cropped_image_data, tile_data):
        """
        Initializes the Dual-Model Chatbot.
        - Model A (Investigator): Forced to use tools to gather data.
        - Model B (Analyst): Pure text model that interprets the findings.
        """
        self.client = genai.Client()
        self.text_data = text_data                    # list: per-page extracted text
        self.full_page_data = image_data              # list: full-page PNG bytes
        self.cropped_image_data = cropped_image_data  # dict: page -> cropped PNG bytes
        self.tiles_data = tile_data                   # dict: page -> [tile PNG bytes, ...]
        self.tile_metadata = tile_metadata            # dict: str(page) -> str(tile) -> metadata
        self.model_name = model_name  # "gemini-3-pro-preview"
        # ---------------------------------------------------------
        # 1. Define Tools (Including the special 'Stop' tool)
        # ---------------------------------------------------------
        self.tools = [
            types.Tool(
                function_declarations=[
                    # Existing Tools
                    types.FunctionDeclaration(
                        name="get_page_by_index",
                        description="Fetch PDF page text and image as well as tile metadata.",
                        parameters=types.Schema(
                            type="object",
                            properties={"index": types.Schema(type="integer")},
                            required=["index"]
                        )
                    ),
                    types.FunctionDeclaration(
                        name="get_tiles",
                        description="Fetch tile images by indices.",
                        parameters=types.Schema(
                            type="object",
                            properties={
                                "page_num": types.Schema(type="integer"),
                                "indices": types.Schema(type="array", items=types.Schema(type="integer"))
                            },
                            required=["page_num", "indices"]
                        )
                    ),
                    # CONTROL TOOL: lets the Investigator signal it is done.
                    types.FunctionDeclaration(
                        name="investigation_complete",
                        description="Call this ONLY when you have gathered sufficient information to answer the user's question.",
                        parameters=types.Schema(
                            type="object",
                            properties={"reason": types.Schema(type="string")},
                            required=["reason"]
                        )
                    )
                ]
            )
        ]
        # ---------------------------------------------------------
        # 2. Configure Model A: The Investigator
        # ---------------------------------------------------------
        investigator_system_prompt = (f'''
You are a Data Retrieval Worker for MEP drawings. You will be gathering data to answer the users question.
Your ONLY job is to call tools to gather information to answer the user's prompt.
You have access to:
- Global context about how these drawings are structured.
- Metadata for every page.
- High-resolution page images and tile images.
Proper Workflow:
1. Consider the users question and examine the page metadata in detail.
- Identify all potentially relevant pages that may assist in your answer
2. Call the pages one or two at a time in order of which one is most likely to be helpful for the users question.
- The pages will return the visual as well as detailed metadata of what tiles ar included in the page and what is in each tile
- Follow up this query by locating the specific tiles releevant to the users question as they are invaluable when answering detailed questions.
- ALWAYS CALL TILES IF YOU CAN FIND RELEVANT ONES TO THE USERS QUESTION.
3. Consider weather or not the previously queried information will be enough to answer the users question. If so call the investigation_complete function, if not return to step 1.
RULES:
1. You CANNOT speak to the user. You can ONLY call tools.
2. You are in FORCE TOOL mode.
3. When you have gathered enough information from pages/tiles, you MUST call 'investigation_complete'.
4. If comparing two states (e.g. Demolition vs New Work), retrieve BOTH before calling complete.
5. Tile images are **MANDATORY** if feasible to provide them. Always request ALL potentially relevant tiles, not just the most relevant.
6. tiles will be stitched together before being passed to the user. If you need to call non-touching tiles call them in seperate tool calls.
- eg. [21,22,23] okay (vertically stacked tiles are fine as well)
- eg [4, 22] not allowed
Global Context: {json.dumps(global_context)}
Page Metadata: {json.dumps(global_metadata)}
**REMEMBER: ALWAYS CALL TILES**
ALso - if you are describing anything visual be sure to search for the legend to explain symbols and callouts. Also supply this image to the analyst.
''')
        self.investigator_config = types.GenerateContentConfig(
            system_instruction=investigator_system_prompt,
            tools=self.tools,
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(mode="ANY")  # <--- FORCED TOOL USE
            )
        )
        self.investigator_chat = self.client.chats.create(
            model=self.model_name,
            config=self.investigator_config
        )
        # ---------------------------------------------------------
        # 3. Configure Model B: The Analyst
        # ---------------------------------------------------------
        self.analyst_system_prompt = (f'''
You are an industry expert in architecture, MEP engineering, structural engineering, and construction documentation. (Analyst).
You will receive a user question and a detailed LOG of an investigation performed by a worker.
You will receive stitched together tiles that show a detailed close up view of key areas of the image. These are your superpower, always give them alot of attention.
The results of the get_tiles tool call is one image of all requested tiles stitched together so there are no overlap concerns between tiles.
Tile metadata should be used to help you examine the tiles but the stitched together image should be the ground truth, always use this when counting equipment, examining layouts, etc.
ex. Count peices of equipment.
1. Examine the iamage of the stitched together tile.
- Note unique visual landmarks near each item (e.g., 'near the left door', 'near the right wall').
2. Identify: Identify all potentially relevant markers and symbols and prepare to describe them to the user.
3. Determine: Determine which symbols correspond to unique peices of equipment.
- Cross check across document text given to you for symbol meanings and consistency in your report.
3. Final Count: What is the total number of unique equipment peices in the room?"
Your job is to read the logs (which contain text data and image analysis) and answer the user's question.
Your goals:
1. Give the user an accurate, detailed, spatially-grounded interpretation of the drawings.
2. Provide high-level reasoning steps (visible chain-of-thought) without revealing internal scratchpads.
3. Avoid contradictions in spatial flow analysis (e.g., tracing a duct from a source only to discover later it did not connect).
β†’ Before answering, verify the entire path by examining relevant tiles.
When you receive a question:
1. **Restate** the user question in brief.
2. **Perform spatial reasoning using a structured approach:**
- Identify the component(s).
- Locate their exact position(s) on the drawing using tile metadata (x/y, bounding boxes, labels).
- Verify all upstream/downstream connections before describing them.
- Trace flows fully BEFORE writing your final answer.
- Give the user a detailed description of spatial landmarks that are relavant to what you are discussing so they can reference the drawing themselves while reading your response.
NOTES:
- Do a detailed search of all information given to you for relevant information before answering.
- Attempt to locate a legend or notes list to help you decifer what is going on visually in the drawing.
3. **Answer formatting**:
- **Provide clear references: β€œLocated at the upper-left quadrant near Room 212,” β€œSouth of AHU-1,” etc.
- ALWAYS explain relevant equipment, numbers and locations
- ALWAYS explain the callout or symbol relevant to what you are discussing and describe it visually.
- Never give vague answers eg. 'exisiting ductwork', instead say 'the ductwork leading from x to y is remiainign in place and connected to diffuser z shown as _ symbol in the drawing'.
- This includes symbols (eg. hexagonal, callouts, squares) for equipment, boxes, etc. ALWAYS describe what every relevent symbol looks like.
- Describe flow paths step-by-step in spatial order (e.g., β€œfrom AHU-1 β†’ main supply riser β†’ branch duct β†’ Gymnasium”) This is NOT the actual path followed in the drawing.
- If tiles were inspected, briefly mention which tiles provided key details.
- Always mention both the page index and page name that you can find the information you are discussing on.
***IMPORTANT**
If you are unsure about something say so. Never make assumptions about details. When making statements site where you found it in the drawing, what symbol is relevant, etc.
If somethign is unclear state that.
Here is the original metadata:
Global Context: {json.dumps(global_context)}
Page Metadata: {json.dumps(global_metadata)}
''')
        self.analyst_config = types.GenerateContentConfig(
            system_instruction=self.analyst_system_prompt,
            thinking_config=types.ThinkingConfig(thinking_level="high"),
            # No tools for the analyst
        )
        self.analyst_chat = self.client.chats.create(
            model=self.model_name,
            config=self.analyst_config
        )
        # Per-instance state; accumulates across ask() calls until the bot is
        # rebuilt (see reset_bot at module level).
        self.investigation_logs = []     # human-readable log lines
        self.collected_images = []       # PIL images for the gallery
        self.analyst_image_parts = []    # genai Parts handed to the Analyst
        self.raw_tool_responses = []     # verbatim tool outputs for the Analyst

    # =========================================================================
    # TOOL LOGIC (internal helpers)
    # =========================================================================
    def _get_page_by_index(self, index):
        """Return {"text", "image_bytes", "mime_type"} for one page.

        Prefers the cropped "drawing" image plus tile metadata; falls back to
        the full-page render when no crop exists for this page.

        Raises:
            ValueError: if *index* is outside the document.
        """
        if index < 0 or index >= len(self.text_data):
            raise ValueError(f"Page index {index} out of bounds.")
        try:
            return {
                "text": f"PAGE {index} TEXT:\n{self.text_data[index]}\nTILE META:\n{json.dumps(self.tile_metadata.get(str(index), {}))}",
                "image_bytes": self.cropped_image_data[index],
                "mime_type": "image/png"
            }
        except (KeyError, IndexError, TypeError):
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and unrelated programming errors.
            # Only a missing cropped image should trigger the fallback.
            return {
                "text": f"PAGE {index} TEXT:\n{self.text_data[index]}\n",
                "image_bytes": self.full_page_data[index],
                "mime_type": "image/png"
            }

    def _get_tiles(self, page_num, indices):
        """
        Retrieves specific tiles, stitches them into a single image,
        and returns the combined image bytes.

        Returns {} for an empty *indices* list, otherwise
        {"stitched_image": {"image_bytes": ..., "mime_type": "image/png"}}.

        Raises:
            ValueError: unknown tile, or coordinate metadata missing.
        """
        images = []
        positions = []
        # 1. Collect Image objects and top-left coordinates.
        for index in indices:
            # BUG FIX: tiles_data is a dict keyed by page number, so the
            # original `page_num < len(self.tiles_data)` compared a page
            # *number* against the *count* of pages that have tiles — a
            # sparse/high-numbered page could be wrongly rejected (or a
            # missing page accepted and KeyError raised).  Use membership.
            if page_num in self.tiles_data and 0 <= index < len(self.tiles_data[page_num]):
                img_bytes = self.tiles_data[page_num][index]
                img = Image.open(io.BytesIO(img_bytes))
                images.append(img)
                try:
                    # tile_metadata[str(page)][str(tile)]['coords'] starts
                    # with the tile's top-left [x, y] in page pixel space;
                    # only that corner is needed for stitching.
                    coords = self.tile_metadata[str(page_num)][str(index)]['coords']
                    positions.append((int(coords[0]), int(coords[1])))
                except (KeyError, IndexError, AttributeError):
                    raise ValueError(f"Metadata/Coordinates missing for Tile {index} on Page {page_num}")
            else:
                raise ValueError(f"Tile {index} on Page {page_num} not found.")
        if not images:
            return {}
        # 2. Normalize coordinates so the top-left-most tile sits at (0, 0).
        min_x = min(x for x, y in positions)
        min_y = min(y for x, y in positions)
        normalized_positions = [(x - min_x, y - min_y) for x, y in positions]
        # 3. Canvas must cover every tile's offset + size.
        total_width = max(pos[0] + img.width for pos, img in zip(normalized_positions, images))
        total_height = max(pos[1] + img.height for pos, img in zip(normalized_positions, images))
        # 4. Paste every tile onto a white RGB canvas.
        stitched_image = Image.new('RGB', (total_width, total_height), (255, 255, 255))
        for img, pos in zip(images, normalized_positions):
            stitched_image.paste(img, pos)
        output_buffer = io.BytesIO()
        stitched_image.save(output_buffer, format='PNG')
        stitched_bytes = output_buffer.getvalue()
        # 5. Return a single stitched result; the distinct tiles are one unit.
        return {
            "stitched_image": {
                "image_bytes": stitched_bytes,
                "mime_type": "image/png"
            }
        }

    # =========================================================================
    # THE DUAL-MODEL LOOP
    # =========================================================================
    def ask(self, question):
        """Answer *question*, yielding (log_text, gallery_images, answer).

        Generator streamed by the Gradio UI: intermediate yields carry the
        running log and collected images with answer=None; the final yield
        carries the Analyst's text.  State accumulates on the instance
        across calls until reset_bot() rebuilds the bot.
        """
        yield (
            f"\nπŸš€ STARTING INVESTIGATION PROCESS for: {question}",
            self.collected_images if self.collected_images else None,  # gallery (None on first run)
            None  # <-- No analyst answer yet
        )
        # --- PHASE 1: THE INVESTIGATION (Model A) ---
        try:
            # Send initial message. Since mode="ANY", response MUST be a tool call.
            response = self.investigator_chat.send_message(question + "Remember to always ask for ALL tiles associated with the question!")
        except Exception as e:
            # BUG FIX: this is a generator, so the original
            # `return f"..."` silently discarded the error message (a
            # generator's return value never reaches a for-loop consumer).
            # Yield it in the answer slot, then stop.
            yield (
                "\n".join(self.investigation_logs),
                self.collected_images,
                f"⚠️ Critical Error starting investigation: {e}"
            )
            return
        max_steps = 15  # hard cap on tool-call rounds
        step = 0
        investigation_done = False
        while not investigation_done and step < max_steps:
            step += 1
            # The model may emit a free-text thought before the tool call.
            if response.text:
                log_entry = f"πŸ•΅οΈ Investigator Thought: {response.text.strip()}"
                self.investigation_logs.append(log_entry)
            if not response.function_calls:
                # If we are in ANY mode and get no function calls, something broke violently.
                self.investigation_logs.append("⚠️ Investigator returned no tools in ANY mode.")
                break
            parts_to_send_back = []
            for tool_call in response.function_calls:
                name = tool_call.name
                args = tool_call.args
                self.investigation_logs.append(f" > Executing Tool: {name}")
                # 1. CHECK FOR STOP SIGNAL
                if name == "investigation_complete":
                    self.investigation_logs.append("βœ… Investigator signaled completion.")
                    self.investigation_logs.append(f"πŸ›‘ ACTION: Investigation Complete. Reason: {args.get('reason')}")
                    investigation_done = True
                    break  # Break the inner for-loop
                # 2. EXECUTE REAL TOOLS
                result_log = ""
                function_result_dict = {}
                image_parts = []
                try:
                    if name == "get_page_by_index":
                        idx = int(args["index"])
                        self.investigation_logs.append(f'Fetching page {idx}')
                        data = self._get_page_by_index(idx)
                        function_result_dict = {"text": data["text"], "status": "success"}
                        self.raw_tool_responses.append({
                            "tool_name": name,
                            "args": args,
                            "raw_result": data["text"],  # full text returned by the tool
                        })
                        image_parts.append(types.Part.from_bytes(data=data["image_bytes"], mime_type=data["mime_type"]))
                        img = Image.open(io.BytesIO(data["image_bytes"]))
                        self.collected_images.append(img)
                        self.analyst_image_parts.append(types.Part.from_bytes(data=data["image_bytes"], mime_type=data["mime_type"]))
                        result_log = f"Fetched Page {idx}. Text length: {len(data['text'])} chars."
                    elif name == "get_tiles":
                        p_num = int(args["page_num"])
                        # Handle list vs single int oddities in model output.
                        raw_indices = args["indices"]
                        self.investigation_logs.append(f'Fetching tiles {raw_indices} from page {p_num}')
                        t_indices = [int(raw_indices)] if isinstance(raw_indices, (int, str)) else [int(x) for x in raw_indices]
                        tile_dict = self._get_tiles(p_num, t_indices)
                        function_result_dict = {"status": "success", "tiles": list(tile_dict.keys())}
                        for t_id, t_data in tile_dict.items():
                            image_parts.append(types.Part.from_bytes(data=t_data["image_bytes"], mime_type=t_data["mime_type"]))
                            self.analyst_image_parts.append(types.Part.from_bytes(data=t_data["image_bytes"], mime_type=t_data["mime_type"]))
                            img = Image.open(io.BytesIO(t_data["image_bytes"]))
                            self.collected_images.append(img)
                        result_log = f"Fetched Tiles {t_indices} from Page {p_num}."
                    else:
                        raise ValueError(f"Unknown tool: {name}")
                except Exception as e:
                    # Report tool failures back to the model so it can retry.
                    self.investigation_logs.append(f" ❌ Tool Error: {e}")
                    function_result_dict = {"error": str(e)}
                    result_log = f"Error executing {name}: {e}"
                # Log the action for the Analyst
                self.investigation_logs.append(f"πŸ”§ ACTION: Called {name} with {args}")
                self.investigation_logs.append(f" RESULT: {result_log}")
                # Build response for Investigator
                parts_to_send_back.append(
                    types.Part.from_function_response(
                        name=name,
                        response={"result": function_result_dict}
                    )
                )
                parts_to_send_back.extend(image_parts)
            # Stream the running log + gallery to the UI.
            yield (
                "\n".join(self.investigation_logs),
                self.collected_images,
                None  # <-- No analyst answer yet
            )
            # Check if we broke due to completion
            if investigation_done:
                break
            # Send results back to Investigator to continue loop
            try:
                response = self.investigator_chat.send_message(parts_to_send_back)
            except Exception as e:
                yield (
                    "\n".join(self.investigation_logs),
                    self.collected_images,
                    f"⚠️ Investigator API Error: {e}"
                )
                # BUG FIX: the original fell through and looped again on the
                # stale `response`, re-executing the same tool calls until
                # max_steps. Stop the generator after surfacing the error.
                return
        # --- PHASE 2: THE ANALYSIS (Model B) ---
        yield (
            "\nπŸ“ Investigation finished. Handing over to Analyst...",
            self.collected_images,
            None
        )
        # Build final handoff text
        final_prompt = (
            f"USER QUESTION: {question}\n\n"
            f"=== INVESTIGATION LOGS ===\n"
            + "\n".join(self.investigation_logs) + "\n"
            f"==========================\n\n"
            f"=== RAW TOOL RESPONSES ===\n"
            + json.dumps(self.raw_tool_responses, indent=2) + "\n"
            f"===========================\n\n"
            "Based on the logs and raw tool outputs above, provide a detailed answer to the user."
        )
        print(len(final_prompt))  # debug: size of the analyst handoff
        # ---------------------------------------------------------
        # Assemble the full multimodal message for the Analyst:
        # the text handoff followed by every collected image part.
        # ---------------------------------------------------------
        self.analyst_contents = [types.Part.from_text(text=final_prompt)]
        self.analyst_contents.extend(self.analyst_image_parts)
        try:
            analyst_response = self.analyst_chat.send_message(self.analyst_contents)
            yield (
                "\n".join(self.investigation_logs),
                self.collected_images,
                analyst_response.text
            )
        except Exception as e:
            yield (
                "\n".join(self.investigation_logs),
                self.collected_images,
                f"⚠️ Analyst API Error: {e}"
            )
#%%
# Arguments used to build (and rebuild, via reset_bot) the chatbot instance.
BOT_ARGS = dict(
    model_name="gemini-3-pro-preview",
    global_context=global_context,
    global_metadata=page_metadata,
    tile_metadata=tile_metadata,
    text_data=text_list,
    image_data=image_bytes_list,
    cropped_image_data=cropped_bytes_list,
    tile_data=tile_bytes
)
# The current bot instance
bot = DrawingChatbot(**BOT_ARGS)
# (Removed: a dead CLI conversation loop that was parked in a throwaway
# triple-quoted string literal — it was never executed.)
#%%
def reset_bot():
    """Swap in a fresh DrawingChatbot and clear the UI outputs.

    Returns (None, message): None empties the gallery, the message is shown
    in the text pane.
    """
    global bot  # reassign the module-level bot
    bot = DrawingChatbot(**BOT_ARGS)
    return None, "πŸ”„ Bot has been reset. Ask a new question!"
def to_bytes(img):
    """Coerce any image-like object (bytes, PIL, numpy, Gemini dict) to bytes."""
    # Raw bytes / bytearray: just normalise to immutable bytes.
    if isinstance(img, (bytes, bytearray)):
        return bytes(img)

    # PIL image: serialize as PNG.
    if isinstance(img, Image.Image):
        sink = io.BytesIO()
        img.save(sink, format="PNG")
        return sink.getvalue()

    # numpy array: round-trip through PIL, then PNG.
    if isinstance(img, np.ndarray):
        sink = io.BytesIO()
        Image.fromarray(img).save(sink, format="PNG")
        return sink.getvalue()

    # Gemini-style wrapper: image parts usually look like {"bytes": b"..."}.
    if isinstance(img, dict) and "bytes" in img:
        return img["bytes"]

    # Last resort: a stable textual representation.
    return repr(img).encode("utf-8")
def hash_bytes(b):
    """Hex MD5 digest of *b* — used only for gallery de-duplication, not security."""
    digest = hashlib.md5(b)
    return digest.hexdigest()
# ---- Wrapping bot.ask() so Gradio can use the generator directly ----
def run_investigation(question, profile: gr.OAuthProfile | None):
    """
    Streams gallery + text while preventing duplicate gallery images.
    """
    uid = user_id_from_profile(profile)
    if uid is None:
        raise gr.Error("Please sign in with Hugging Face to use this demo.")
    allowed, remaining = check_and_increment_quota(uid)
    if not allowed:
        raise gr.Error(f"Usage limit reached: {MAX_RUNS_PER_USER} runs per user.")
    if remaining <= 2:
        gr.Warning(f"⚠️ Only {remaining} run(s) left!")
    else:
        gr.Info(f"βœ“ Runs remaining: {remaining}")

    gallery = []
    seen_hashes = set()
    for logs, images, answer in bot.ask(question):
        # De-duplicate by content hash so re-fetched pages/tiles appear once.
        for img in images or []:
            fingerprint = hash_bytes(to_bytes(img))
            if fingerprint not in seen_hashes:
                seen_hashes.add(fingerprint)
                gallery.append(img)
        # Show the final answer once available, otherwise the running log.
        yield gallery, (answer if answer else logs)
# ------------------- GRADIO UI ----------------------
# Styling for every button tagged elem_classes="custom-btn": dark border,
# cream background, blue glow on hover.
custom_css = """
.custom-btn,
.custom-btn button,
button.custom-btn {
border: 2px solid #222 !important;
background: #FFF8DC !important;
color: black !important;
border-radius: 6px !important;
font-weight: 600 !important;
padding: 6px 14px !important;
box-shadow: 0 0 0 1px #ccc !important;
transition: 0.2s ease-in-out;
}
.custom-btn:hover,
button.custom-btn:hover {
border-color: #2563eb !important;
background: white !important;
box-shadow: 0px 0px 6px rgba(37, 99, 235, 0.5) !important;
color: black !important;
}
"""
with gr.Blocks(css=custom_css) as demo:
    # HF OAuth login; run_investigation requires a signed-in profile.
    gr.LoginButton()
    with gr.Group(visible=True) as main_app:
        gr.Markdown("""
# πŸ—οΈ MEP Drawing Investigator
Ask a question about the drawing below.
Retreived Images will appear live as they are retrieved.
""")
        # Rebuilds the bot via reset_bot (clears accumulated logs/images).
        reset_btn = gr.Button("πŸ” Reset Bot", elem_classes="custom-btn")
        # Lets users download the source PDF the bot reasons over.
        gr.DownloadButton(
            label="πŸ“„ Download Drawings PDF",
            value="NorthMaconPark.pdf",
            elem_classes="custom-btn"
        )
        # Input
        question = gr.Textbox(
            label="Your Question",
            placeholder="e.g., 'Trace the supply duct to Room 204'",
            lines=2
        )
        submit_btn = gr.Button("πŸ” Investigate", elem_classes="custom-btn")
        # --- GALLERY DISPLAY ---
        image_display = gr.Gallery(
            label="πŸ“Έ Drawing Tiles (Live)",
            elem_id="image-pane",
            show_label=True,
            preview=True,  # enables swapping/preview
            columns=3,  # display grid
            height="auto"
        )
        # --- TEXT WINDOW ---
        text_display = gr.Markdown(
            label="πŸ“ Investigation Logs / Final Answer",
            elem_id="text_display"
        )
    # Clicking the button launches streaming; run_investigation is a
    # generator, so gallery and text update live (the OAuth profile is
    # injected automatically by Gradio as the extra argument).
    submit_btn.click(
        fn=run_investigation,
        inputs=question,
        outputs=[image_display, text_display]
    )
    reset_btn.click(
        fn=reset_bot,
        inputs=None,
        outputs=[image_display, text_display]
    )
demo.launch()