File size: 7,429 Bytes
4f24301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import json
from json import JSONDecoder
import re
from typing import Dict, Any, Optional, Union, List
import numpy as np
from PIL import Image

from deepforest_agent.tools.deepforest_tool import DeepForestPredictor
from deepforest_agent.utils.state_manager import session_state_manager
from deepforest_agent.utils.image_utils import validate_image_path
from deepforest_agent.conf.config import Config

# Single module-level predictor, constructed once at import time and reused by
# every run_deepforest_object_detection call in this process.
# NOTE(review): any model-loading cost in DeepForestPredictor.__init__ is paid
# at import — confirm this is intended.
deepforest_predictor = DeepForestPredictor()

def _detection_error(message: str) -> Dict[str, Any]:
    """Build the uniform error result returned by run_deepforest_object_detection."""
    return {
        "detection_summary": message,
        "detections_list": [],
        "total_detections": 0,
        "status": "error",
    }


def run_deepforest_object_detection(
    session_id: str,
    model_names: Optional[List[str]] = None,
    patch_size: int = Config.DEEPFOREST_DEFAULTS["patch_size"],
    patch_overlap: float = Config.DEEPFOREST_DEFAULTS["patch_overlap"],
    iou_threshold: float = Config.DEEPFOREST_DEFAULTS["iou_threshold"],
    thresh: float = Config.DEEPFOREST_DEFAULTS["thresh"],
    alive_dead_trees: bool = Config.DEEPFOREST_DEFAULTS["alive_dead_trees"]
) -> Dict[str, Any]:
    """
    Run DeepForest object detection on the image stored in a user's session.

    The session's ``image_file_path`` is used when it passes
    ``validate_image_path``. Otherwise the session's in-memory
    ``current_image`` (a PIL image) is converted to a NumPy array and used
    instead. When the predictor returns an annotated image, it is stored back
    into the session under ``annotated_image``.

    Args:
        session_id (str): Unique session identifier for this user
        model_names: List of model names to use ("tree", "bird", "livestock").
            When omitted or None, all three are used.
        patch_size: Patch size for each window in pixels (not geographic units). The size for the crops used to cut the input image/raster into smaller pieces.
        patch_overlap: Patch overlap among windows. The horizontal and vertical overlap among patches (must be between 0-1).
        iou_threshold: Minimum IoU overlap among predictions between windows to be suppressed.
        thresh: Score threshold used to filter bboxes after soft-NMS is performed.
        alive_dead_trees: Whether to classify trees as alive/dead

    Returns:
        Dictionary with the keys "detection_summary", "detections_list",
        "total_detections" and "status" ("success" or "error"). On error,
        "detection_summary" carries the error message, "detections_list" is
        empty and "total_detections" is 0.
    """
    if model_names is None:
        # Fresh list per call: a shared mutable default could be mutated
        # downstream and leak state into later calls.
        model_names = ["tree", "bird", "livestock"]

    if not session_state_manager.session_exists(session_id):
        return _detection_error(f"Session {session_id} not found.")

    image_file_path = session_state_manager.get(session_id, "image_file_path")
    current_image = session_state_manager.get(session_id, "current_image")

    if image_file_path and not validate_image_path(image_file_path):
        print(f"Warning: Invalid image file path {image_file_path}, falling back to PIL image")
        image_file_path = None

    # Checked after path validation so that an invalid/empty path with no
    # in-memory image yields a clear message instead of failing on None.size.
    if not image_file_path and current_image is None:
        return _detection_error(f"No image available for detection in session {session_id}.")

    predictor_kwargs = {
        "model_names": model_names,
        "patch_size": patch_size,
        "patch_overlap": patch_overlap,
        "iou_threshold": iou_threshold,
        "thresh": thresh,
        "alive_dead_trees": alive_dead_trees,
    }

    try:
        if image_file_path:
            print(f"DeepForest: Processing image from file path: {image_file_path}")
            image_source = {"image_file_path": image_file_path}
        else:
            print(f"DeepForest: Processing PIL image (size: {current_image.size})")
            image_source = {"image_data_array": np.array(current_image)}

        detection_summary, annotated_image, detections_list = deepforest_predictor.predict_objects(
            **image_source, **predictor_kwargs
        )

        if annotated_image is not None:
            # Assumes predict_objects returns the annotated image as a NumPy
            # array accepted by Image.fromarray — TODO confirm.
            session_state_manager.set(session_id, "annotated_image", Image.fromarray(annotated_image))

        return {
            "detection_summary": detection_summary,
            "detections_list": detections_list,
            "total_detections": len(detections_list),
            "status": "success",
        }

    except Exception as e:
        # Tool boundary: report failures to the agent as a result, not a raise.
        error_msg = f"Error during image detection in session {session_id}: {str(e)}"
        print(f"DeepForest Detection Error: {error_msg}")
        return _detection_error(error_msg)

def extract_all_tool_calls(text: str) -> List[Dict[str, Any]]:
    """
    Collect every tool call embedded in a model response.

    Two formats are recognised, in order of preference:

    1. JSON objects wrapped in ``<tool_call>...</tool_call>`` tags.
    2. Only when (1) yields nothing: bare JSON objects anywhere in the text,
       found by attempting a JSON decode at every ``{``.

    An object counts as a tool call only if it is a dict that has both a
    ``"name"`` and an ``"arguments"`` key.

    Args:
        text: Raw model output, possibly containing zero or more tool calls.

    Returns:
        The decoded tool-call dicts in order of appearance; an empty list
        when none are found.
    """
    def looks_like_call(candidate: Any) -> bool:
        return isinstance(candidate, dict) and "name" in candidate and "arguments" in candidate

    found: List[Dict[str, Any]] = []

    # Preferred format: JSON payload wrapped in <tool_call> tags.
    for payload in re.findall(r'<tool_call>\s*(\{.*?\})\s*</tool_call>', text, re.DOTALL):
        try:
            parsed = json.loads(payload.strip())
        except json.JSONDecodeError as err:
            print(f"Failed to parse XML tool call JSON: {err}")
            continue
        if looks_like_call(parsed):
            print(f"Found valid XML tool call: {parsed}")
            found.append(parsed)

    # Fallback: scan for bare JSON objects only if no tagged call was found.
    if not found:
        scanner = JSONDecoder()
        cursor = text.find('{')
        while cursor != -1:
            advance = 1
            try:
                parsed, consumed = scanner.raw_decode(text[cursor:])
            except ValueError:
                pass
            else:
                if looks_like_call(parsed):
                    print(f"Found valid raw JSON tool call: {parsed}")
                    found.append(parsed)
                    # Skip the whole object so its nested braces are not rescanned.
                    advance = consumed
            cursor = text.find('{', cursor + advance)

    print(f"Total tool calls extracted: {len(found)}")
    return found

def _normalize_tool_arguments(tool_arguments: Any) -> Dict[str, Any]:
    """Coerce a tool call's ``arguments`` value into a keyword-argument dict."""
    if tool_arguments is None:
        return {}
    if isinstance(tool_arguments, str):
        # Some chat formats (e.g. OpenAI-style function calling) encode the
        # arguments object as a JSON string rather than a nested object.
        tool_arguments = json.loads(tool_arguments) if tool_arguments.strip() else {}
    if not isinstance(tool_arguments, dict):
        raise TypeError(
            f"tool arguments must be a JSON object, got {type(tool_arguments).__name__}"
        )
    return tool_arguments


def handle_tool_call(tool_name: str, tool_arguments: Union[Dict[str, Any], str, None], session_id: str) -> Union[str, Dict[str, Any]]:
    """
    Handle tool call execution from tool name and arguments.

    Args:
        tool_name (str): The name of the tool to be executed.
        tool_arguments: Arguments for the tool, either as a dict, as a
            JSON-encoded object string, or None (treated as no arguments).
        session_id: Unique session identifier for this user

    Returns:
        Either an error message (str) or the tool execution result (dict).
        Unknown tools, malformed arguments and exceptions raised by the tool
        are all reported as error strings rather than raised.
    """
    print("Tool Call Detected:")
    print(f"Tool Name: {tool_name}")
    print(f"Arguments: {tool_arguments}")

    if tool_name != "run_deepforest_object_detection":
        error_msg = f"Unknown tool: {tool_name}"
        print(f"Unknown Tool: {error_msg}")
        return error_msg

    try:
        arguments = _normalize_tool_arguments(tool_arguments)
        return run_deepforest_object_detection(session_id=session_id, **arguments)
    except Exception as e:
        # Covers malformed arguments (bad JSON, non-object, unexpected keyword
        # names) as well as failures inside the tool itself.
        error_msg = f"Error executing {tool_name} in session {session_id}: {str(e)}"
        print(f"Tool Execution Failed: {error_msg}")
        return error_msg