samwell committed on
Commit
879e071
·
1 Parent(s): 86d018b

Use custom Blocks interface with dedicated image output for grounding visualizations

Browse files
Files changed (1) hide show
  1. app.py +47 -9
app.py CHANGED
@@ -66,6 +66,7 @@ agent = Agent(
66
  print(f"Tools loaded: {len(tools)}")
67
 
68
  import glob
 
69
 
70
  def chat(message, history):
71
  config = {"configurable": {"thread_id": "default"}}
@@ -88,23 +89,60 @@ def chat(message, history):
88
  assistant_message = response["messages"][-1].content
89
 
90
  # Check for grounding visualization images
 
91
  viz_files = glob.glob("temp/grounding_*.png")
92
  if viz_files:
93
  # Get the most recent visualization
94
  viz_files.sort(key=os.path.getmtime, reverse=True)
95
  latest_viz = viz_files[0]
 
 
 
 
96
 
97
- # Append image reference that Gradio can render
98
- assistant_message += f"\n\n**Grounding Visualization:**\n\n<img src='/file={latest_viz}' style='max-width: 600px;'/>"
99
 
100
- return assistant_message
 
 
101
 
102
- demo = gr.ChatInterface(
103
- fn=chat,
104
- title="MedRAX2 - Medical AI Assistant",
105
- description=f"Device: {device} | Tools: {len(tools)} loaded",
106
- multimodal=True
107
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  if __name__ == "__main__":
110
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
66
  print(f"Tools loaded: {len(tools)}")
67
 
68
  import glob
69
+ from PIL import Image
70
 
71
  def chat(message, history):
72
  config = {"configurable": {"thread_id": "default"}}
 
89
  assistant_message = response["messages"][-1].content
90
 
91
  # Check for grounding visualization images
92
+ viz_image = None
93
  viz_files = glob.glob("temp/grounding_*.png")
94
  if viz_files:
95
  # Get the most recent visualization
96
  viz_files.sort(key=os.path.getmtime, reverse=True)
97
  latest_viz = viz_files[0]
98
+ try:
99
+ viz_image = Image.open(latest_viz)
100
+ except:
101
+ pass
102
 
103
+ return assistant_message, viz_image
 
104
 
105
# Custom Blocks UI: chat transcript plus a dedicated image panel for the
# grounding visualizations produced by chat().
with gr.Blocks() as demo:
    gr.Markdown(
        f"# MedRAX2 - Medical AI Assistant\n"
        f"**Device:** {device} | **Tools:** {len(tools)} loaded"
    )

    chatbot = gr.Chatbot(type="messages")
    viz_output = gr.Image(label="Grounding Visualization", visible=True)
    msg = gr.MultimodalTextbox(
        label="Message",
        placeholder="Upload an X-ray image and ask a question...",
        file_types=["image"],
    )

    def respond(message, chat_history):
        """Run one chat turn: record the user/assistant messages and surface any viz image."""
        bot_message, viz_image = chat(message, chat_history)

        # Normalize the user turn into a Chatbot "messages" content value.
        if not isinstance(message, dict):
            user_content = message
        else:
            user_text = message.get("text", "")
            user_files = message.get("files", [])
            if user_files:
                # Multimodal turn: attached image first, then the accompanying text.
                user_content = [
                    {"type": "image", "image": user_files[0]},
                    {"type": "text", "text": user_text},
                ]
            else:
                user_content = user_text

        chat_history.append({"role": "user", "content": user_content})
        chat_history.append({"role": "assistant", "content": bot_message})

        # Clear the textbox, refresh the transcript, update the image panel.
        return "", chat_history, viz_image

    msg.submit(respond, [msg, chatbot], [msg, chatbot, viz_output])

    gr.Examples(
        examples=[
            [{"text": "What do you see in this X-ray?", "files": []}],
            [{"text": "Can you show me where exactly using grounding?", "files": []}],
        ],
        inputs=msg,
    )
146
 
147
if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) so the app is reachable from outside the
    # container — the standard setup for a hosted Space on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)