hmgill committed on
Commit
201a9d0
·
verified ·
1 Parent(s): 4290eac

Upload 15 files

Browse files
cellemetry/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cellemetry: Google ADK Agent for Microscopy Image Analysis
3
+ """
4
+ from .agents import root_agent, analyst_agent, manager_agent
5
+ from .config import AnalysisDeps, AnalystResult, ManagerSummary, ComponentRequest, BoundingBox
6
+ from .tools import ANALYST_TOOLS
7
+
8
+ __all__ = [
9
+ "root_agent",
10
+ "analyst_agent",
11
+ "manager_agent",
12
+ "AnalysisDeps",
13
+ "AnalystResult",
14
+ "ManagerSummary",
15
+ "ComponentRequest",
16
+ "BoundingBox",
17
+ "ANALYST_TOOLS",
18
+ ]
cellemetry/agent.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agent definition file for ADK CLI (adk web / adk run).
3
+ Place this at the package root for: adk web bio_agent
4
+ """
5
+ from cellemetry.agents import root_agent
6
+
7
+ # ADK CLI looks for 'root_agent' or 'agent' at module level
8
+ agent = root_agent
cellemetry/agents/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Agents package for bio_agent.
3
+ Exports the agent hierarchy with manager as root.
4
+ """
5
+ from .analyst import analyst_agent
6
+ from .manager import manager_agent
7
+
8
+ # The root agent for ADK runner
9
+ root_agent = manager_agent
10
+
11
+ __all__ = [
12
+ "root_agent",
13
+ "manager_agent",
14
+ "analyst_agent",
15
+ ]
cellemetry/agents/analyst.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analyst Agent - Expert microscopy image analyst.
3
+ Segments biological components and computes statistics.
4
+ """
5
from google.adk.agents import LlmAgent
# NOTE(review): AgentTool is unused in this module (the manager wraps the
# analyst, not the other way around) - kept to avoid touching file imports.
from google.adk.tools.agent_tool import AgentTool

from ..tools import ANALYST_TOOLS


# Step-by-step operating procedure for the analyst LLM. The MASK_FILE plumbing
# in Steps 6-7 is load-bearing: the statistics tools can only locate masks via
# the exact /tmp paths echoed back by apply_sam3_tool.
ANALYST_INSTRUCTION = """
You are an expert microscopy image analyst.

**Your Goal:** Identify major biological components, segment them using SAM3, analyze the segmentations, and provide a report.

**Step 1: Resolution Parsing:**
Look for physical resolution info (e.g., "0.27 microns/px", "0.5 um per pixel").
If found, note it for reference. If not found, proceed without physical units.

**Step 2: Visual Analysis**
Identify distinct structures (e.g., Nuclei, Cells) in the image.

**Step 3: Define Tool Inputs**
Decompose each structure into three words:
- `color`: ONE adjective (e.g., "green")
- `morphology`: ONE adjective (e.g., "irregular")
- `entity`: ONE noun (e.g., "cell" - singular!)

**Step 4: Box Selection**
Select 1-3 representative bounding boxes per structure (0-1000 normalized).
Ensure boxes cover the full object.

**Step 5: Execution**
Call `apply_sam3_tool` for each structure type.

**Step 6: CRITICAL - Use Exact Filenames**
The segmentation tool returns a result containing "MASK_FILE=/tmp/data_xxx.npz".
You MUST extract and use this EXACT filename when calling statistics tools.

Example:
- Segmentation returns: "SUCCESS: Found 15 'green irregular cell' objects. MASK_FILE=/tmp/data_green_cell.npz"
- When calling get_basic_stats, use filename="/tmp/data_green_cell.npz" (the EXACT path from MASK_FILE)

**Step 7: Quantification**
- Call `get_basic_stats` with the EXACT filename from segmentation for every structure found.
- Call `get_spatial_stats` with the EXACT filename specifically for Cells.
- Call `get_relationship_stats` with BOTH exact filenames ONLY if both Cells and Nuclei were found.

**Step 8: Save Results**
Save all data using `save_excel_tool`.

Return your findings as structured data including:
- pixel_size_used (if applicable)
- components_found (list of segmented components)
- excel_path
- stats objects
"""

# NOTE(review): model id "gemini-3-pro-preview" differs from the manager's
# "gemini-2.5-pro" in manager.py - confirm both ids are intentional/available.
analyst_agent = LlmAgent(
    name="analyst",
    model="gemini-3-pro-preview",
    description="Expert microscopy analyst that segments and quantifies biological structures.",
    instruction=ANALYST_INSTRUCTION,
    tools=ANALYST_TOOLS,
    # The analyst's structured result lands in session state under this key.
    output_key="analyst_result",
)
cellemetry/agents/manager.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Manager Agent - Workflow orchestrator.
3
+ Coordinates analysis tasks and synthesizes user-facing reports.
4
+ """
5
from google.adk.agents import LlmAgent
from google.adk.tools.agent_tool import AgentTool

from .analyst import analyst_agent


# The manager never touches images or masks itself: it extracts metadata from
# the user request, delegates to the analyst (wrapped as a tool below), and
# rewrites the structured result into a human-readable report.
MANAGER_INSTRUCTION = """
You are the Cellemetry Workflow Manager.

**Goal**: Orchestrate microscopy image analysis and deliver user-friendly summaries.

**Workflow:**
1. Receive the user's request and image context.
2. Extract resolution info (e.g., "0.27 microns/px") if present in the request.
3. Delegate analysis to the `analyst` tool - pass the original request along with any extracted metadata.
4. Receive the structured analysis results.
5. Synthesize a human-readable summary:
   - Write a clear executive summary
   - Highlight key biological findings (density, size, relationships)
   - List where output files were saved

**Important**: When calling the analyst tool, pass the full user request so the analyst has all context about what to analyze.
"""

# Wrap analyst as a tool for the manager
analyst_tool = AgentTool(agent=analyst_agent)

manager_agent = LlmAgent(
    name="manager",
    model="gemini-2.5-pro",
    description="Orchestrates microscopy analysis workflows and synthesizes reports.",
    instruction=MANAGER_INSTRUCTION,
    tools=[analyst_tool],
    # Final user-facing summary lands in session state under this key.
    output_key="manager_summary",
)
cellemetry/config/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Config package - schemas and dependency definitions.
3
+ """
4
+ from .schemas import (
5
+ AnalystResult,
6
+ ManagerSummary,
7
+ ComponentRequest,
8
+ BoundingBox,
9
+ )
10
+ from .dependencies import AnalysisDeps, get_deps_from_state
11
+
12
+ __all__ = [
13
+ "AnalystResult",
14
+ "ManagerSummary",
15
+ "ComponentRequest",
16
+ "BoundingBox",
17
+ "AnalysisDeps",
18
+ "get_deps_from_state",
19
+ ]
cellemetry/config/dependencies.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dependency management using ADK session state.
3
+ SAM models and shared resources are stored in state for tool access.
4
+ """
5
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Any


@dataclass
class AnalysisDeps:
    """Container for analysis dependencies - stored in session state.

    Fields are written to session state under ``app:``-prefixed keys
    (see to_state_dict) so ADK tools can rebuild the container later via
    get_deps_from_state.
    """
    sam_model: Any        # loaded SAM model (or None when running mocked)
    sam_processor: Any    # matching processor for the model (or None)
    image_path: Path      # path of the image under analysis
    device: str           # torch device string, e.g. "cpu" / "cuda"
    pixel_size_microns: Optional[float] = None  # physical scale, if known

    def to_state_dict(self) -> dict:
        """Convert to dict for session state storage (``app:`` namespace)."""
        return {
            "app:sam_model": self.sam_model,
            "app:sam_processor": self.sam_processor,
            # Stored as str so the state value stays serializable.
            "app:image_path": str(self.image_path),
            "app:device": self.device,
            "app:pixel_size_microns": self.pixel_size_microns,
        }


def get_deps_from_state(state: dict) -> AnalysisDeps:
    """Reconstruct AnalysisDeps from session state.

    Missing keys degrade gracefully: device falls back to "cpu" and a
    missing image path becomes an empty Path instead of raising the
    TypeError that Path(None) used to produce.
    """
    raw_path = state.get("app:image_path") or ""
    return AnalysisDeps(
        sam_model=state.get("app:sam_model"),
        sam_processor=state.get("app:sam_processor"),
        image_path=Path(raw_path),
        device=state.get("app:device", "cpu"),
        pixel_size_microns=state.get("app:pixel_size_microns"),
    )
cellemetry/config/schemas.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for structured inputs/outputs.
3
+ ADK supports Pydantic models for structured output via output_schema.
4
+ """
5
from pydantic import BaseModel, Field
from typing import List, Optional, Dict


# --- SAM3 Inputs ---
class BoundingBox(BaseModel):
    """Axis-aligned box in 0-1000 normalized image coordinates."""
    ymin: int = Field(description="Top Y (0-1000)")
    xmin: int = Field(description="Left X (0-1000)")
    ymax: int = Field(description="Bottom Y (0-1000)")
    xmax: int = Field(description="Right X (0-1000)")


class ComponentRequest(BaseModel):
    """One structure to segment: a three-word prompt plus exemplar boxes."""
    entity: str = Field(description="Generic object name (e.g., 'cell'). 1 word.")
    color: str = Field(description="Dominant color adjective (e.g., 'green').")
    morphology: str = Field(description="Dominant shape adjective (e.g., 'irregular').")
    bboxes: List[BoundingBox]


# --- Stats Schemas ---
class BasicStats(BaseModel):
    """Per-structure morphology summary; areas are expressed in `unit`."""
    count: int
    area_mean: float
    area_std: float
    unit: str = "px²"


class SpatialStats(BaseModel):
    """Spatial distribution metrics (NND = nearest-neighbour distance)."""
    avg_nnd: float
    std_nnd: float
    density: float
    avg_neighbor_count: float
    std_neighbor_count: float
    dist_unit: str = "px"
    density_unit: str = "N/A"


class RelationalStats(BaseModel):
    """Cell-to-nucleus pairing statistics (cell/nucleus area ratios)."""
    matched_pairs: int
    avg_ratio: float
    std_ratio: float


# --- Analyst Output ---
class SegmentedComponent(BaseModel):
    """One segmented structure and the output files produced for it."""
    label: str
    description: str
    mask_filename: str
    data_filename: str
    count: int


class AnalystResult(BaseModel):
    """Structured output from the Analyst agent."""
    pixel_size_used: Optional[float] = None
    components_found: List[SegmentedComponent] = Field(default_factory=list)
    excel_path: str = ""
    cell_stats: Optional[BasicStats] = None
    nuclei_stats: Optional[BasicStats] = None
    spatial_stats: Optional[SpatialStats] = None
    relational_stats: Optional[RelationalStats] = None


# --- Manager Output ---
class ManagerSummary(BaseModel):
    """Final user-facing summary from the Manager."""
    executive_summary: str
    key_findings: List[str]
    file_locations: Dict[str, str] = Field(
        default_factory=dict,
        description="Map of 'description' to 'filepath'"
    )
cellemetry/services/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ Services package - core processing logic.
3
+ """
4
+ from . import sam
5
+ from . import analysis
6
+
7
+ __all__ = ["sam", "analysis"]
cellemetry/services/analysis.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Statistical analysis functions for segmentation masks.
3
+ Unchanged from original implementation.
4
+ """
5
+ import os
6
+ import xlsxwriter
7
+ import numpy as np
8
+ from scipy.spatial import KDTree
9
+ from skimage.measure import regionprops
10
+ from typing import Optional, Dict, Any
11
+
12
+
13
def load_masks(filename: str) -> Optional[np.ndarray]:
    """Load a .npz mask stack from disk.

    Prefers the 'masks' array; otherwise falls back to the first array in
    the archive. Returns None (and logs the path) when loading fails.
    """
    try:
        data = np.load(filename)
        return data['masks'] if 'masks' in data else data[data.files[0]]
    except Exception as e:
        # Include the offending path so failures are traceable in logs
        # (the message previously printed a literal "(unknown)" placeholder).
        print(f"Error loading {filename}: {e}")
        return None


def get_basic_stats(filename: str, pixel_scale: Optional[float] = None) -> Dict[str, Any]:
    """Calculate Count, Mean Area, and Std Dev Area.

    Args:
        filename: Path to the .npz mask stack.
        pixel_scale: Physical size of one pixel in microns; when given,
            areas are reported in µm², otherwise in px².

    Returns:
        dict with count, area_mean, area_std and the area unit.
    """
    masks = load_masks(filename)
    if masks is None or masks.size == 0:
        return {"count": 0, "area_mean": 0.0, "area_std": 0.0, "unit": "px²"}

    # Per-mask pixel area: sum of mask pixels over each (H, W) slice.
    areas_px = np.sum(masks, axis=(1, 2))

    if pixel_scale:
        # Area scales with the square of the linear pixel size.
        conversion_factor = pixel_scale ** 2
        areas = areas_px * conversion_factor
        unit = "µm²"
    else:
        areas = areas_px
        unit = "px²"

    return {
        "count": int(len(areas)),
        "area_mean": float(np.mean(areas)),
        "area_std": float(np.std(areas)),
        "unit": unit
    }
45
+
46
+
47
def get_spatial_stats(filename: str, pixel_scale: Optional[float] = None) -> Dict[str, Any]:
    """Calculate spatial metrics including NND and density."""
    stack = load_masks(filename)
    no_data = {
        "avg_nnd": 0.0, "std_nnd": 0.0,
        "density": 0.0,
        "avg_neighbor_count": 0.0, "std_neighbor_count": 0.0,
        "dist_unit": "px", "density_unit": "N/A"
    }

    if stack is None or stack.size == 0:
        return no_data

    # One centroid per mask; masks with no labelled region are skipped.
    points = []
    for mask in stack:
        regions = regionprops(mask.astype(int))
        if regions:
            points.append(regions[0].centroid)

    points = np.array(points)
    count = len(points)
    px_area = stack[0].size

    # Pick the unit system up front.
    if pixel_scale:
        scale = pixel_scale
        dist_unit = "µm"
        phys_area = (px_area * (pixel_scale ** 2)) / (1000 ** 2)  # µm² -> mm²
        density_unit = "cells/mm²"
    else:
        scale = 1.0
        dist_unit = "px"
        phys_area = px_area / 10000.0
        density_unit = "cells/10k px²"

    # Too few objects for neighbour statistics: report density only.
    if count < 2:
        sparse = no_data.copy()
        sparse.update({
            "density": float(count / phys_area) if phys_area > 0 else 0,
            "dist_unit": dist_unit,
            "density_unit": density_unit
        })
        return sparse

    tree = KDTree(points)

    # k=2: nearest hit is the point itself, column 1 is the true neighbour.
    all_dists, _ = tree.query(points, k=2)
    nnd_px = all_dists[:, 1]

    # Local crowding: neighbours within a fixed 100 px radius.
    crowding = [len(hits) - 1 for hits in tree.query_ball_point(points, r=100)]

    return {
        "avg_nnd": float(np.mean(nnd_px) * scale),
        "std_nnd": float(np.std(nnd_px) * scale),
        "density": float(count / phys_area),
        "avg_neighbor_count": float(np.mean(crowding)),
        "std_neighbor_count": float(np.std(crowding)),
        "dist_unit": dist_unit,
        "density_unit": density_unit
    }
112
+
113
+
114
def analyze_relationships(cell_file: str, nuc_file: str) -> Dict[str, Any]:
    """Calculate Cell/Nucleus overlap ratios."""
    empty = {"matched_pairs": 0, "avg_ratio": 0.0, "std_ratio": 0.0}

    cell_stack = load_masks(cell_file)
    nuc_stack = load_masks(nuc_file)
    if cell_stack is None or nuc_stack is None or cell_stack.size == 0 or nuc_stack.size == 0:
        return empty

    height, width = cell_stack[0].shape

    # Label map: each pixel holds the 1-based index of the covering cell.
    label_map = np.zeros((height, width), dtype=int)
    for idx, cell_mask in enumerate(cell_stack):
        label_map[cell_mask > 0] = idx + 1

    # Match each nucleus to the cell under its centroid (guard-clause style).
    ratios = []
    for nuc_mask in nuc_stack:
        regions = regionprops(nuc_mask.astype(int))
        if not regions:
            continue
        row, col = (int(c) for c in regions[0].centroid)
        if not (0 <= row < height and 0 <= col < width):
            continue
        owner = label_map[row, col]
        if owner == 0:
            continue
        nuc_area = np.sum(nuc_mask)
        if nuc_area > 0:
            ratios.append(np.sum(cell_stack[owner - 1]) / nuc_area)

    if not ratios:
        return empty

    return {
        "matched_pairs": len(ratios),
        "avg_ratio": float(np.mean(ratios)),
        "std_ratio": float(np.std(ratios))
    }
149
+
150
+
151
def save_stats_to_excel(
    base_filename: str,
    cell_stats: Optional[Dict] = None,
    nuc_stats: Optional[Dict] = None,
    spatial_stats: Optional[Dict] = None,
    rel_stats: Optional[Dict] = None
) -> str:
    """Write statistics to a multi-sheet Excel file.

    Args:
        base_filename: Output name; its extension is replaced with .xlsx.
        cell_stats: Optional cell morphology dict (count, area_mean, ...).
        nuc_stats: Optional nuclei morphology dict.
        spatial_stats: Optional spatial distribution dict (density, NND, ...).
        rel_stats: Optional cell/nucleus relationship dict.

    Returns:
        The .xlsx path on success, or an "Error: ..." string on failure.
    """
    filename = os.path.splitext(base_filename)[0] + ".xlsx"

    try:
        workbook = xlsxwriter.Workbook(filename)
        try:
            header_fmt = workbook.add_format({'bold': True, 'bg_color': '#D3D3D3', 'border': 1})
            num_fmt = workbook.add_format({'num_format': '0.00'})

            _write_morphology_sheet(workbook, header_fmt, num_fmt, cell_stats, nuc_stats)
            _write_spatial_sheet(workbook, header_fmt, num_fmt, spatial_stats)
            _write_relational_sheet(workbook, header_fmt, num_fmt, rel_stats)
        finally:
            # Always close the workbook so the handle is not leaked when a
            # write raises mid-way (previously close() was unreachable then).
            workbook.close()
        return filename

    except Exception as e:
        print(f"Error creating Excel file: {e}")
        return f"Error: {e}"


def _write_morphology_sheet(workbook, header_fmt, num_fmt, cell_stats, nuc_stats):
    """Write the always-present Morphology sheet (Cells/Nuclei rows)."""
    ws = workbook.add_worksheet("Morphology")
    ws.write_row(0, 0, ["Structure", "Count", "Mean Area", "StdDev Area", "Unit"], header_fmt)

    row = 1
    for label, stats in (("Cells", cell_stats), ("Nuclei", nuc_stats)):
        # Rows are only emitted for structures that were actually found.
        if stats and stats.get("count", 0) > 0:
            ws.write(row, 0, label)
            ws.write(row, 1, stats.get('count', 0))
            ws.write(row, 2, stats.get('area_mean', 0), num_fmt)
            ws.write(row, 3, stats.get('area_std', 0), num_fmt)
            ws.write(row, 4, stats.get('unit', 'px²'))
            row += 1

    ws.set_column(0, 4, 15)


def _write_spatial_sheet(workbook, header_fmt, num_fmt, spatial_stats):
    """Write the Spatial sheet; skipped entirely when there is no density."""
    if not (spatial_stats and spatial_stats.get("density", 0) > 0):
        return
    ws = workbook.add_worksheet("Spatial")
    ws.write_row(0, 0, [
        "Structure", "Global Density", "Density Unit",
        "Mean NND", "StdDev NND", "Dist Unit",
        "Mean Neighbors (r=100)", "StdDev Neighbors"
    ], header_fmt)

    ws.write(1, 0, "Cells")
    ws.write(1, 1, spatial_stats.get('density', 0), num_fmt)
    ws.write(1, 2, spatial_stats.get('density_unit', 'N/A'))
    ws.write(1, 3, spatial_stats.get('avg_nnd', 0), num_fmt)
    ws.write(1, 4, spatial_stats.get('std_nnd', 0), num_fmt)
    ws.write(1, 5, spatial_stats.get('dist_unit', 'px'))
    ws.write(1, 6, spatial_stats.get('avg_neighbor_count', 0), num_fmt)
    ws.write(1, 7, spatial_stats.get('std_neighbor_count', 0), num_fmt)
    ws.set_column(0, 7, 18)


def _write_relational_sheet(workbook, header_fmt, num_fmt, rel_stats):
    """Write the Relational sheet; skipped when no pairs were matched."""
    if not (rel_stats and rel_stats.get("matched_pairs", 0) > 0):
        return
    ws = workbook.add_worksheet("Relational")
    ws.write_row(0, 0, ["Relationship", "Matched Pairs", "Mean Area Ratio", "StdDev Ratio"], header_fmt)

    ws.write(1, 0, "Cell_to_Nucleus")
    ws.write(1, 1, rel_stats.get('matched_pairs', 0))
    ws.write(1, 2, rel_stats.get('avg_ratio', 0), num_fmt)
    ws.write(1, 3, rel_stats.get('std_ratio', 0), num_fmt)
    ws.set_column(0, 3, 20)
cellemetry/services/sam.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SAM3 segmentation execution.
3
+ Core logic unchanged from original - just updated imports.
4
+ """
5
+ import matplotlib
6
+ matplotlib.use('Agg')
7
+ import matplotlib.pyplot as plt
8
+ import torch
9
+ import torchvision
10
+ import numpy as np
11
+ from PIL import Image
12
+ from skimage.measure import regionprops
13
+
14
+ from ..config.schemas import ComponentRequest
15
+ from ..config.dependencies import AnalysisDeps
16
+
17
+ MIN_SOLIDITY = 0.50
18
+ MIN_CIRCULARITY = 0.1
19
+
20
+ # Use /tmp for all outputs (Cloud Run writable directory)
21
+ OUTPUT_DIR = "/tmp"
22
+
23
+
24
def execute_segmentation(deps: AnalysisDeps, request: ComponentRequest) -> str:
    """
    Execute SAM3 segmentation for the given component request.

    Args:
        deps: Analysis dependencies with SAM model
        request: Component request with color, morphology, entity, bboxes

    Returns:
        String describing results and output filenames. On success the string
        embeds "MASK_FILE=..." / "PLOT_FILE=..." which downstream statistics
        tools parse to locate the outputs.
    """
    text_prompt = f"{request.color} {request.morphology} {request.entity}"
    print(f"\n[Engine] Segmenting: '{text_prompt}' ({len(request.bboxes)} boxes).")

    # Load Image
    try:
        raw_image = Image.open(deps.image_path).convert("RGB")
    except Exception as e:
        return f"Error loading image: {e}"

    width, height = raw_image.size

    # Convert normalized coords (0-1000) to pixel coords; SAM expects
    # [x_min, y_min, x_max, y_max] order.
    sam_input_boxes = []
    for box in request.bboxes:
        y_min = (box.ymin / 1000) * height
        x_min = (box.xmin / 1000) * width
        y_max = (box.ymax / 1000) * height
        x_max = (box.xmax / 1000) * width
        sam_input_boxes.append([x_min, y_min, x_max, y_max])

    if not sam_input_boxes:
        return "No valid boxes provided."

    # Generate consistent filename from request. Uses the module-level
    # OUTPUT_DIR constant instead of hard-coding "/tmp" again. NOTE: this
    # convention must stay in sync with tools/segmentation.py.
    safe_label = f"{request.color}_{request.entity}".replace(" ", "_").lower()
    plot_filename = f"{OUTPUT_DIR}/out_{safe_label}.png"
    data_filename = f"{OUTPUT_DIR}/data_{safe_label}.npz"

    # Check if SAM model is available
    if deps.sam_model is None or deps.sam_processor is None:
        # Return mock result for testing
        return f"[Mock] Would segment '{text_prompt}'. SAM model not loaded. Data file would be: {data_filename}"

    # Prepare inputs: batch of one image, all boxes are positive prompts.
    sam_input_labels = [[1] * len(sam_input_boxes)]
    input_boxes_batch = [sam_input_boxes]

    inputs = deps.sam_processor(
        images=raw_image,
        text=text_prompt,
        input_boxes=input_boxes_batch,
        input_boxes_labels=sam_input_labels,
        return_tensors="pt"
    ).to(deps.device)

    with torch.no_grad():
        outputs = deps.sam_model(**inputs)

    results = deps.sam_processor.post_process_instance_segmentation(
        outputs,
        threshold=0.3,
        target_sizes=inputs["original_sizes"].tolist()
    )[0]

    # Morphology filtering: reject fragmented / ragged candidates.
    keep_indices_morph = []
    for mask_tensor in results["masks"]:
        mask_np = mask_tensor.cpu().numpy()
        mask_np = np.squeeze(mask_np).astype(int)

        if mask_np.ndim != 2:
            keep_indices_morph.append(False)
            continue

        props = regionprops(mask_np)
        if not props:
            keep_indices_morph.append(False)
            continue

        prop = props[0]
        perimeter = prop.perimeter
        # Circularity: 1.0 for a perfect circle, approaches 0 for ragged shapes.
        circularity = (4 * np.pi * prop.area) / (perimeter ** 2) if perimeter > 0 else 0

        is_solid = prop.solidity > MIN_SOLIDITY
        is_round_enough = circularity > MIN_CIRCULARITY
        keep_indices_morph.append(is_solid and is_round_enough)

    # NOTE(review): when NO mask passes the filter, filtering is skipped and
    # everything is kept - presumably a deliberate fallback; confirm.
    if any(keep_indices_morph):
        keep_indices_tensor = torch.tensor(keep_indices_morph, device=results["masks"].device)
        before_count = len(results["masks"])
        results = _filter_results(results, keep_indices_tensor)
        print(f"[Filter] Morphology: Dropped {before_count - len(results['masks'])} debris-like objects.")

    # NMS: deduplicate overlapping detections of the same object.
    pred_boxes = results["boxes"]
    pred_scores = results["scores"]

    if len(pred_scores) > 1:
        keep_indices_nms = torchvision.ops.nms(pred_boxes, pred_scores, iou_threshold=0.3)
        results = _filter_results(results, keep_indices_nms)
        print(f"[NMS] Reduced masks from {len(pred_scores)} to {len(keep_indices_nms)}")

    # Save outputs
    _save_plot(raw_image, results, sam_input_boxes, text_prompt, plot_filename)

    mask_count = len(results['masks'])
    if mask_count > 0:
        masks_list = [m.cpu().numpy().squeeze() for m in results['masks']]
        masks_array = np.array(masks_list)
        np.savez_compressed(data_filename, masks=masks_array)
    else:
        # Still write an (empty) archive so downstream stats tools find a file.
        np.savez_compressed(data_filename, masks=np.array([]))

    print(f"[Engine] Saved {mask_count} masks to {data_filename}")

    # Return with EXACT filename for stats tools to use
    return f"SUCCESS: Found {mask_count} '{text_prompt}' objects. MASK_FILE={data_filename} PLOT_FILE={plot_filename}"
142
+
143
+
144
+ def _filter_results(results, keep_indices):
145
+ """Helper to slice all dictionary keys at once."""
146
+ results["masks"] = results["masks"][keep_indices]
147
+ results["scores"] = results["scores"][keep_indices]
148
+ results["boxes"] = results["boxes"][keep_indices]
149
+ return results
150
+
151
+
152
def _save_plot(image, results, boxes, label, filename):
    """Save visualization of segmentation results."""
    fig, axis = plt.subplots(figsize=(10, 10))
    axis.imshow(image)

    # Overlay each confident mask as a random translucent colour layer.
    for mask, score in zip(results['masks'], results['scores']):
        if score > 0.3:
            arr = mask.cpu().numpy()
            rgba = np.append(np.random.random(3), 0.5)
            h, w = arr.shape[-2:]
            axis.imshow(arr.reshape(h, w, 1) * rgba.reshape(1, 1, -1))

    axis.set_title(f"{label}")
    axis.axis('off')
    fig.savefig(filename)
    plt.close(fig)
cellemetry/tools/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tools package for bio_agent.
3
+ Exports categorized tool collections.
4
+ """
5
+ from .segmentation import apply_sam3_tool
6
+ from .statistics import get_basic_stats, get_spatial_stats, get_relationship_stats
7
+ from .export import save_excel_tool
8
+
9
+ # All tools used by the analyst agent
10
+ ANALYST_TOOLS = [
11
+ apply_sam3_tool,
12
+ get_basic_stats,
13
+ get_spatial_stats,
14
+ get_relationship_stats,
15
+ save_excel_tool,
16
+ ]
17
+
18
+ __all__ = [
19
+ "ANALYST_TOOLS",
20
+ "apply_sam3_tool",
21
+ "get_basic_stats",
22
+ "get_spatial_stats",
23
+ "get_relationship_stats",
24
+ "save_excel_tool",
25
+ ]
cellemetry/tools/export.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data export tools (Excel, CSV, etc.).
3
+ """
4
+ from typing import Optional
5
+ from google.adk.tools.tool_context import ToolContext
6
+
7
+ from ..services import analysis
8
+
9
+
10
def save_excel_tool(
    filename: str,
    cell_stats: Optional[dict] = None,
    nuc_stats: Optional[dict] = None,
    spatial_stats: Optional[dict] = None,
    rel_stats: Optional[dict] = None,
    tool_context: ToolContext = None
) -> dict:
    """
    Save all statistics to a multi-sheet Excel file.

    Args:
        filename: Base filename for the output Excel file
        cell_stats: Optional cell morphology stats
        nuc_stats: Optional nuclei morphology stats
        spatial_stats: Optional spatial distribution stats
        rel_stats: Optional relationship stats
        tool_context: Automatically injected by ADK

    Returns:
        dict with the output filepath
    """
    # Delegate the workbook construction to the analysis service and
    # surface only the resulting path (or error string) to the agent.
    excel_path = analysis.save_stats_to_excel(
        filename, cell_stats, nuc_stats, spatial_stats, rel_stats
    )
    return {"excel_path": excel_path}
cellemetry/tools/segmentation.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Segmentation tools using SAM3.
3
+ """
4
+ from typing import Optional, Union, List
5
+ from google.adk.tools.tool_context import ToolContext
6
+
7
+ from ..config.dependencies import get_deps_from_state
8
+ from ..config.schemas import ComponentRequest, BoundingBox
9
+ from ..services import sam
10
+
11
+
12
def apply_sam3_tool(
    entity: str,
    color: str,
    morphology: str,
    bboxes: List[BoundingBox],
    tool_context: ToolContext
) -> dict:
    """
    Segment biological components using SAM3.

    Args:
        entity: Object name (e.g., 'cell', 'nucleus') - use SINGULAR form
        color: Color adjective (e.g., 'green', 'blue')
        morphology: Shape adjective (e.g., 'irregular', 'round')
        bboxes: Bounding boxes in one of these formats:
            - BoundingBox models
            - List of dicts: [{"ymin": 0, "xmin": 0, "ymax": 100, "xmax": 100}, ...]
            - List of lists: [[ymin, xmin, ymax, xmax], ...]
            Values should be 0-1000 normalized coordinates.
        tool_context: Automatically injected by ADK

    Returns:
        dict with:
            - result: Description of segmentation outcome
            - mask_file: EXACT path to the .npz file containing masks
            - plot_file: Path to visualization image
            - count: Number of objects found
            - label: The "<color> <morphology> <entity>" prompt used
    """
    deps = get_deps_from_state(tool_context.state)
    label = f"{color} {morphology} {entity}"

    # Normalize incoming bboxes: depending on the runtime the model may send
    # parsed BoundingBox objects, plain dicts, or bare 4-element lists.
    bbox_objects = []
    for b in bboxes:
        if isinstance(b, BoundingBox):
            # Already parsed - use as-is. (These previously fell through to
            # the "unrecognized format" branch and were silently dropped.)
            bbox_objects.append(b)
        elif isinstance(b, dict):
            # Format: {"ymin": 0, "xmin": 0, "ymax": 100, "xmax": 100}
            bbox_objects.append(BoundingBox(**b))
        elif isinstance(b, (list, tuple)):
            if len(b) == 4:
                # Assume [ymin, xmin, ymax, xmax] based on schema order
                bbox_objects.append(BoundingBox(
                    ymin=int(b[0]),
                    xmin=int(b[1]),
                    ymax=int(b[2]),
                    xmax=int(b[3])
                ))
            else:
                print(f"[Warning] Skipping invalid bbox: {b}")
        else:
            print(f"[Warning] Skipping unrecognized bbox format: {b}")

    if not bbox_objects:
        return {
            "result": "ERROR: No valid bounding boxes provided",
            "mask_file": None,
            "plot_file": None,
            "count": 0,
            "label": label
        }

    request = ComponentRequest(
        entity=entity,
        color=color,
        morphology=morphology,
        bboxes=bbox_objects
    )

    result_str = sam.execute_segmentation(deps, request)

    # Filenames mirror the convention in services/sam.py - keep in sync.
    safe_label = f"{color}_{entity}".replace(" ", "_").lower()
    mask_file = f"/tmp/data_{safe_label}.npz"
    plot_file = f"/tmp/out_{safe_label}.png"

    # Best-effort parse of "... Found <N> ..." from the result string.
    count = 0
    if "Found" in result_str:
        try:
            count = int(result_str.split("Found")[1].split()[0])
        except (IndexError, ValueError):
            # Narrowed from a bare `except:` so real faults still surface.
            pass

    return {
        "result": result_str,
        "mask_file": mask_file,  # <-- USE THIS EXACT PATH for get_basic_stats, get_spatial_stats
        "plot_file": plot_file,
        "count": count,
        "label": label
    }
cellemetry/tools/statistics.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Statistical analysis tools for morphology and spatial metrics.
3
+ """
4
+ from google.adk.tools.tool_context import ToolContext
5
+
6
+ from ..services import analysis
7
+
8
+
9
def get_basic_stats(
    filename: str,
    tool_context: ToolContext
) -> dict:
    """
    Calculate basic morphology stats (count, area mean/std).

    Args:
        filename: Path to the .npz mask file
        tool_context: Automatically injected by ADK

    Returns:
        dict with count, area_mean, area_std, unit
    """
    # The physical pixel size (microns) is shared via app-scoped state;
    # the heavy lifting happens in services.analysis.
    scale = tool_context.state.get("app:pixel_size_microns")
    return analysis.get_basic_stats(filename, pixel_scale=scale)
25
+
26
+
27
def get_spatial_stats(
    filename: str,
    tool_context: ToolContext
) -> dict:
    """
    Calculate spatial distribution stats (NND, density, neighbor count).

    Args:
        filename: Path to the .npz mask file
        tool_context: Automatically injected by ADK

    Returns:
        dict with spatial metrics
    """
    # Pull the shared physical scale from app-scoped state and delegate.
    scale = tool_context.state.get("app:pixel_size_microns")
    return analysis.get_spatial_stats(filename, pixel_scale=scale)
43
+
44
+
45
def get_relationship_stats(
    cell_file: str,
    nuc_file: str,
    tool_context: ToolContext
) -> dict:
    """
    Analyze cell-nucleus relationships (overlap ratios).

    Args:
        cell_file: Path to cell masks .npz
        nuc_file: Path to nucleus masks .npz
        tool_context: Automatically injected by ADK

    Returns:
        dict with matched_pairs, avg_ratio, std_ratio
    """
    # Pure pass-through: the pairing logic lives in services.analysis.
    return analysis.analyze_relationships(cell_file, nuc_file)