Spaces:
No application file
No application file
File size: 7,429 Bytes
4f24301 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import json
from json import JSONDecoder
import re
from typing import Dict, Any, Optional, Union, List
import numpy as np
from PIL import Image
from deepforest_agent.tools.deepforest_tool import DeepForestPredictor
from deepforest_agent.utils.state_manager import session_state_manager
from deepforest_agent.utils.image_utils import validate_image_path
from deepforest_agent.conf.config import Config
# Single module-level predictor, constructed once when this module is
# imported and reused by every call to run_deepforest_object_detection.
deepforest_predictor = DeepForestPredictor()
def run_deepforest_object_detection(
    session_id: str,
    model_names: List[str] = ["tree", "bird", "livestock"],
    patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
    patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
    iou_threshold: float = Config.DEEPFOREST_DEFAULTS["iou_threshold"],
    thresh: float = Config.DEEPFOREST_DEFAULTS["thresh"],
    alive_dead_trees: bool = Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
) -> Dict[str, Any]:
    """
    Run DeepForest object detection on the image stored in a session.

    The image is read from the session's ``image_file_path`` when that path
    passes ``validate_image_path``; otherwise the session's in-memory
    ``current_image`` (a PIL image) is converted to a numpy array and used
    instead. If the predictor returns an annotated image, it is stored back
    into the session under ``annotated_image``.

    Args:
        session_id (str): Unique session identifier for this user
        model_names: List of model names to use ("tree", "bird", "livestock").
            A single string is accepted as a one-element list. The value is
            copied before use so the shared default list can never be
            mutated by downstream code.
        patch_size: Patch size for each window in pixels (not geographic units). The size for the crops used to cut the input image/raster into smaller pieces.
        patch_overlap: Patch overlap among windows. The horizontal and vertical overlap among patches (must be between 0-1).
        iou_threshold: Minimum IoU overlap among predictions between windows to be suppressed.
        thresh: Score threshold used to filter bboxes after soft-NMS is performed.
        alive_dead_trees: Whether to classify trees as alive/dead
    Returns:
        Dictionary with keys ``detection_summary``, ``detections_list``,
        ``total_detections`` and ``status`` (``"success"`` or ``"error"``).
        On error ``detection_summary`` carries the error message,
        ``detections_list`` is empty and ``total_detections`` is 0.
    """
    def _error(message: str) -> Dict[str, Any]:
        # One shape for every error result so callers can rely on the keys.
        return {
            "detection_summary": message,
            "detections_list": [],
            "total_detections": 0,
            "status": "error",
        }

    # The literal default is kept in the signature for interface
    # compatibility, but it is never handed out directly: copy it (and any
    # caller-supplied list) so in-place changes downstream cannot leak into
    # later calls.
    if isinstance(model_names, str):
        model_names = [model_names]
    elif model_names is not None:
        model_names = list(model_names)

    # Validate session exists
    if not session_state_manager.session_exists(session_id):
        return _error(f"Session {session_id} not found.")

    image_file_path = session_state_manager.get(session_id, "image_file_path")
    current_image = session_state_manager.get(session_id, "current_image")
    # Truthiness (not `is None`) so an empty-string path is treated as absent.
    if not image_file_path and current_image is None:
        return _error(f"No image available for detection in session {session_id}.")

    if image_file_path and not validate_image_path(image_file_path):
        if current_image is None:
            # Nothing to fall back to: report it here instead of failing
            # later with an opaque error on `current_image.size`.
            return _error(
                f"Invalid image file path {image_file_path} and no image "
                f"available for detection in session {session_id}."
            )
        print(f"Warning: Invalid image file path {image_file_path}, falling back to PIL image")
        image_file_path = None

    # Arguments shared by both input modes; the image source is added below.
    predict_kwargs: Dict[str, Any] = {
        "model_names": model_names,
        "patch_size": patch_size,
        "patch_overlap": patch_overlap,
        "iou_threshold": iou_threshold,
        "thresh": thresh,
        "alive_dead_trees": alive_dead_trees,
    }
    try:
        if image_file_path:
            print(f"DeepForest: Processing image from file path: {image_file_path}")
            predict_kwargs["image_file_path"] = image_file_path
        else:
            print(f"DeepForest: Processing PIL image (size: {current_image.size})")
            predict_kwargs["image_data_array"] = np.array(current_image)

        detection_summary, annotated_image, detections_list = (
            deepforest_predictor.predict_objects(**predict_kwargs)
        )

        if annotated_image is not None:
            session_state_manager.set(session_id, "annotated_image", Image.fromarray(annotated_image))

        return {
            "detection_summary": detection_summary,
            "detections_list": detections_list,
            "total_detections": len(detections_list),
            "status": "success"
        }
    except Exception as e:
        # Tool boundary: convert any failure into an error result for the agent.
        error_msg = f"Error during image detection in session {session_id}: {str(e)}"
        print(f"DeepForest Detection Error: {error_msg}")
        return _error(error_msg)
def extract_all_tool_calls(text: str) -> List[Dict[str, Any]]:
    """
    Extract all tool call information from model output text.

    Two formats are recognised, in order:

    1. JSON objects wrapped in ``<tool_call>...</tool_call>`` tags.
    2. Only if step 1 yields no valid call: bare JSON objects anywhere in
       the text, scanned left to right.

    A parsed object counts as a tool call only if it is a dict containing
    both a ``"name"`` and an ``"arguments"`` key. During the bare-JSON scan,
    a recognised tool call is consumed whole, while any other JSON object
    is searched for nested tool calls.

    Args:
        text: The model's output text that may contain multiple tool calls.
            ``None`` or an empty string yields an empty list.
    Returns:
        List of dictionaries with tool call info (empty list if none found)
    """
    tool_calls: List[Dict[str, Any]] = []
    if not text:
        print(f"Total tool calls extracted: {len(tool_calls)}")
        return tool_calls

    # Method 1: Wrapped in XML
    xml_pattern = r'<tool_call>\s*(\{.*?\})\s*</tool_call>'
    for match in re.findall(xml_pattern, text, re.DOTALL):
        try:
            result = json.loads(match.strip())
        except json.JSONDecodeError as e:
            print(f"Failed to parse XML tool call JSON: {e}")
            continue
        if isinstance(result, dict) and "name" in result and "arguments" in result:
            print(f"Found valid XML tool call: {result}")
            tool_calls.append(result)

    # Method 2: If no valid XML tool call was parsed, scan for raw JSON
    if not tool_calls:
        decoder = JSONDecoder()
        search_from = 0
        while True:
            start = text.find('{', search_from)
            if start == -1:
                break
            try:
                # Decode in place from `start` rather than decoding
                # text[start:]: slicing copied the rest of the text for every
                # '{' (quadratic on long outputs). `end` is an absolute index.
                result, end = decoder.raw_decode(text, start)
            except ValueError:
                search_from = start + 1
                continue
            if isinstance(result, dict) and "name" in result and "arguments" in result:
                print(f"Found valid raw JSON tool call: {result}")
                tool_calls.append(result)
                search_from = end
            else:
                # Not a tool call: step inside it to look for nested ones.
                search_from = start + 1

    print(f"Total tool calls extracted: {len(tool_calls)}")
    return tool_calls
def handle_tool_call(tool_name: str, tool_arguments: Dict[str, Any], session_id: str) -> Union[str, Dict[str, Any]]:
    """
    Handle tool call execution from tool name and arguments.

    Only ``run_deepforest_object_detection`` is supported; it is invoked with
    the trusted ``session_id`` plus the model-supplied arguments as keyword
    arguments.

    Args:
        tool_name (str): The name of the tool to be executed.
        tool_arguments (Dict[str, Any]): A dictionary of arguments for the tool.
            ``None`` (e.g. ``"arguments": null`` in the model output) is
            treated as "no arguments", so the tool's defaults apply. A
            ``session_id`` key supplied by the model is ignored in favour of
            the ``session_id`` parameter.
        session_id: Unique session identifier for this user
    Returns:
        The tool's result dict on success, or an error message (str) for an
        unknown tool, non-mapping arguments, or an exception raised while
        executing the tool.
    """
    from collections.abc import Mapping

    print("Tool Call Detected:")
    print(f"Tool Name: {tool_name}")
    print(f"Arguments: {tool_arguments}")

    if tool_name != "run_deepforest_object_detection":
        error_msg = f"Unknown tool: {tool_name}"
        print(f"Unknown Tool: {error_msg}")
        return error_msg

    if tool_arguments is None:
        tool_arguments = {}
    if not isinstance(tool_arguments, Mapping):
        error_msg = (
            f"Error executing {tool_name} in session {session_id}: "
            f"arguments must be a JSON object, got {type(tool_arguments).__name__}"
        )
        print(f"Tool Execution Failed: {error_msg}")
        return error_msg

    # The session comes from the caller, never from model output; dropping a
    # model-provided "session_id" also avoids a duplicate-keyword TypeError.
    # Building a new dict leaves the caller's mapping unmodified.
    call_arguments = {key: value for key, value in tool_arguments.items() if key != "session_id"}

    try:
        return run_deepforest_object_detection(session_id=session_id, **call_arguments)
    except Exception as e:
        error_msg = f"Error executing {tool_name} in session {session_id}: {str(e)}"
        print(f"Tool Execution Failed: {error_msg}")
        return error_msg
|