hmgill commited on
Commit
92e2d37
Β·
verified Β·
1 Parent(s): 535d69d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import asyncio
3
+ import os
4
+ import glob
5
+ import torch
6
+ from pathlib import Path
7
+ from PIL import Image
8
+
9
+ # GenAI & ADK Imports
10
+ from google.adk.runners import InMemoryRunner
11
+ from google.genai import types
12
+
13
+ # Project Imports
14
+ from cellemetry import root_agent
15
+ from cellemetry.config import AnalysisDeps
16
+ from transformers import Sam3Processor, Sam3Model
17
+
18
+ # --- Global State for Heavy Models ---
19
+ # We load the model once when the app starts to avoid reloading per request.
20
+ MODEL_CACHE = {
21
+ "model": None,
22
+ "processor": None,
23
+ "device": "cpu"
24
+ }
25
+
26
+ def load_models():
27
+ """Initialize SAM3 model."""
28
+ if MODEL_CACHE["model"] is not None:
29
+ return
30
+
31
+ print("--- Loading SAM3 Model ---")
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+ MODEL_CACHE["device"] = device
34
+
35
+ try:
36
+ # Note: Ensure you have access to facebook/sam3 or use a public alternative
37
+ MODEL_CACHE["model"] = Sam3Model.from_pretrained("facebook/sam3").to(device)
38
+ MODEL_CACHE["processor"] = Sam3Processor.from_pretrained("facebook/sam3")
39
+ print(f"βœ… SAM3 loaded on {device}")
40
+ except Exception as e:
41
+ print(f"⚠️ SAM3 load failed (using mock): {e}")
42
+ # Allow app to start even if model fails (will fall back to mock logic if implemented)
43
+
44
+ # Load immediately on startup
45
+ load_models()
46
+
47
+ # --- Core Analysis Function ---
48
+ async def run_analysis(image_path_str, user_prompt, progress=gr.Progress()):
49
+ """
50
+ Async generator that runs the agent and yields updates to the UI.
51
+ """
52
+ if not image_path_str:
53
+ yield "⚠️ Please upload an image first.", None, None
54
+ return
55
+
56
+ # Clean up previous runs
57
+ for f in glob.glob("/tmp/out_*.png") + glob.glob("/tmp/data_*.npz") + glob.glob("/tmp/*.xlsx"):
58
+ try:
59
+ os.remove(f)
60
+ except:
61
+ pass
62
+
63
+ image_path = Path(image_path_str)
64
+
65
+ # 1. Setup Dependencies
66
+ deps = AnalysisDeps(
67
+ sam_model=MODEL_CACHE["model"],
68
+ sam_processor=MODEL_CACHE["processor"],
69
+ image_path=image_path,
70
+ device=MODEL_CACHE["device"],
71
+ pixel_size_microns=None # Agent will parse this from prompt
72
+ )
73
+
74
+ # 2. Initialize Runner
75
+ runner = InMemoryRunner(agent=root_agent, app_name="cellemetry_demo")
76
+ session = await runner.session_service.create_session(
77
+ app_name="cellemetry_demo",
78
+ user_id="demo_user",
79
+ state=deps.to_state_dict()
80
+ )
81
+
82
+ # 3. Prepare Content
83
+ image_bytes = image_path.read_bytes()
84
+ content = types.Content(
85
+ role="user",
86
+ parts=[
87
+ types.Part.from_text(text=user_prompt),
88
+ types.Part.from_bytes(data=image_bytes, mime_type="image/png"),
89
+ ]
90
+ )
91
+
92
+ # 4. Stream Execution
93
+ logs = [f"πŸš€ Starting analysis on {MODEL_CACHE['device']}..."]
94
+ log_text = "\n".join(logs)
95
+ yield log_text, None, None
96
+
97
+ async for event in runner.run_async(
98
+ user_id="demo_user",
99
+ session_id=session.id,
100
+ new_message=content,
101
+ ):
102
+ author = event.author
103
+
104
+ # Capture Tool Calls
105
+ if event.get_function_calls():
106
+ for fc in event.get_function_calls():
107
+ logs.append(f"πŸ”§ **{author}** calling tool: `{fc.name}`")
108
+
109
+ # Capture Text (Streaming)
110
+ if event.content and event.content.parts:
111
+ for part in event.content.parts:
112
+ if hasattr(part, 'text') and part.text:
113
+ if event.partial:
114
+ # Update the last log line if it's the same thought
115
+ if logs[-1].startswith(f"πŸ’¬ **{author}**"):
116
+ logs[-1] = f"πŸ’¬ **{author}**: {part.text}..."
117
+ else:
118
+ logs.append(f"πŸ’¬ **{author}**: {part.text}...")
119
+ else:
120
+ logs.append(f"βœ… **{author}**: {part.text}")
121
+
122
+ # Yield updated logs immediately
123
+ yield "\n\n".join(logs), None, None
124
+
125
+ # 5. Retrieve Final Results
126
+ logs.append("\n🏁 **Analysis Complete.** gathering files...")
127
+ yield "\n\n".join(logs), None, None
128
+
129
+ # Collect output images
130
+ output_images = glob.glob("/tmp/out_*.png")
131
+
132
+ # Collect excel report
133
+ excel_files = glob.glob("/tmp/*.xlsx")
134
+ report_file = excel_files[0] if excel_files else None
135
+
136
+ logs.append(f"\nπŸ“‹ Found {len(output_images)} segmentation maps and {1 if report_file else 0} report.")
137
+ yield "\n\n".join(logs), output_images, report_file
138
+
139
+
140
+ # --- Gradio UI Layout ---
141
+ with gr.Blocks(title="Cellemetry Agent", theme=gr.themes.Soft()) as demo:
142
+ gr.Markdown("# πŸ”¬ Cellemetry: Agentic Microscopy Analysis")
143
+ gr.Markdown("Upload a microscopy image and ask the agent to identify, segment, and quantify biological structures.")
144
+
145
+ with gr.Row():
146
+ with gr.Column(scale=1):
147
+ # Input Section
148
+ img_input = gr.Image(type="filepath", label="Microscopy Image", height=300)
149
+
150
+ prompt_input = gr.Textbox(
151
+ label="Analysis Request",
152
+ lines=3,
153
+ value="Identify the green irregular cells and blue round nuclei. Provide a statistical report on morphology and density.",
154
+ placeholder="E.g., 'Find all red cells and calculate their density.'"
155
+ )
156
+
157
+ with gr.Accordion("Advanced Settings", open=False):
158
+ gr.Markdown(f"Running on: **{MODEL_CACHE['device']}**")
159
+
160
+ run_btn = gr.Button("πŸ§ͺ Run Analysis", variant="primary", size="lg")
161
+
162
+ with gr.Column(scale=2):
163
+ # Output Section
164
+ with gr.Tabs():
165
+ with gr.Tab("Live Agent Logs"):
166
+ # Markdown component to render bolding/formatting in logs
167
+ log_output = gr.Markdown(label="Agent Thought Process", height=500)
168
+
169
+ with gr.Tab("Visual Results"):
170
+ gallery = gr.Gallery(label="Segmentation Maps", columns=2)
171
+
172
+ with gr.Tab("Data Report"):
173
+ file_output = gr.File(label="Download Excel Report")
174
+
175
+ # Connect the Async Function
176
+ run_btn.click(
177
+ fn=run_analysis,
178
+ inputs=[img_input, prompt_input],
179
+ outputs=[log_output, gallery, file_output]
180
+ )
181
+
182
+ if __name__ == "__main__":
183
+ demo.queue().launch()