Anthony Liang committed on
Commit
7c21061
·
1 Parent(s): 4e80be3

add prediction app and script for running inference on trained model

Browse files
Files changed (6) hide show
  1. .gitignore +10 -0
  2. README.md +46 -1
  3. app.py +338 -0
  4. predict_trace.py +99 -0
  5. requirements.txt +7 -0
  6. trajectory_viz.py +135 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ *.png
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.pyw
7
+ *.pyz
8
+ *.pywz
9
+ *.pyzw
10
+ *.pyzwz
README.md CHANGED
@@ -9,4 +9,49 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  pinned: false
10
  ---
11
 
12
+ # Trace Model Visualizer
13
+
14
+ Gradio app for visualizing trace/trajectory predictions from [mihirgrao/trace-model](https://huggingface.co/mihirgrao/trace-model).
15
+
16
+ ## Features
17
+
18
+ - **Image input**: Upload an image
19
+ - **Trace prediction**: Model predicts trajectory points from the image
20
+ - **Visual overlay**: Trace is overlaid on the image with gradient coloring (green start → red end)
21
+ - **Coordinate output**: Predicted trace points are printed below
22
+
23
+ ## Installation
24
+
25
+ ```bash
26
+ pip install -r requirements.txt
27
+ ```
28
+
29
+ ## Usage
30
+
31
+ ### Gradio app
32
+
33
+ ```bash
34
+ python app.py
35
+ ```
36
+
37
+ Then open the URL (default: http://localhost:7860).
38
+
39
+ 1. Click **Load Model** to load the trace model (first run downloads from Hugging Face)
40
+ 2. Upload an image and optionally enter a task instruction (e.g. "Pick up the red block")
41
+ 3. Click **Run Inference**
42
+ 4. View the overlay image and predicted trace points
43
+
44
+ ### CLI script
45
+
46
+ ```bash
47
+ python predict_trace.py image.png
48
+ python predict_trace.py image.png -i "Pick up the red block"
49
+ python predict_trace.py image.png -o output_trace.png -i "Stack the cube on the block"
50
+ python predict_trace.py image.png -o output.png -m mihirgrao/trace-model
51
+ ```
52
+
53
+ - `image` – Path to input image
54
+ - `-i, --instruction` – Task / language instruction (e.g. "Pick up the red block")
55
+ - `-o, --output` – Where to save the overlay (default: `<image>_trace.png`)
56
+ - `-m, --model-id` – Model ID (default: mihirgrao/trace-model)
57
+ - `-p, --prompt` – Full prompt override (if set, ignores `-i`)
app.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio app for Trace Model inference visualization.
4
+
5
+ Takes an image, runs the trace model to predict trajectory points,
6
+ overlays the trace on the image, and displays the predicted coordinates.
7
+
8
+ Model: https://huggingface.co/mihirgrao/trace-model
9
+ """
10
+
11
+ import os
12
+ import tempfile
13
+ import logging
14
+ from typing import Optional, Tuple
15
+
16
+ import gradio as gr
17
+ import numpy as np
18
+ import torch
19
+ from PIL import Image
20
+ from transformers import AutoModelForImageTextToText, AutoProcessor
21
+
22
+ from trajectory_viz import extract_trajectory_from_text, visualize_trajectory_on_image
23
+
24
+ try:
25
+ from qwen_vl_utils import process_vision_info
26
+ except ImportError:
27
+ process_vision_info = None
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
# Default model path (Hugging Face Hub)
DEFAULT_MODEL_ID = "mihirgrao/trace-model"

# Trace format instruction (always appended)
TRACE_FORMAT = (
    "Predict the trajectory or trace in this image. "
    "Output the coordinates as a list of [x, y] pairs, e.g. [[0.1, 0.2], [0.3, 0.4], ...]. "
    "Use normalized coordinates between 0 and 1."
)


def build_prompt(instruction: str = "") -> str:
    """Compose the full model prompt: an optional task line followed by the trace-format spec."""
    task = instruction.strip()
    return f"Task: {task}\n\n{TRACE_FORMAT}" if task else TRACE_FORMAT
47
+
48
# Global model state (lazy loading).
# "model"/"processor" hold the loaded transformers objects (or None);
# "model_id" records which HF model ID they came from, so repeat loads are skipped.
_model_state = {
    "model": None,
    "processor": None,
    "model_id": None,
}


def load_model(model_id: str = DEFAULT_MODEL_ID) -> Tuple[bool, str]:
    """Load the trace model and processor into _model_state.

    Args:
        model_id: Hugging Face model ID to load.

    Returns:
        (success, message) — never raises; failures are reported in the message.
    """
    global _model_state

    # Fast path: requested model is already loaded.
    if _model_state["model"] is not None and _model_state["model_id"] == model_id:
        return True, f"Model already loaded: {model_id}"

    try:
        # Clear previous model so its memory can be reclaimed before loading another.
        if _model_state["model"] is not None:
            del _model_state["model"]
            del _model_state["processor"]
            _model_state["model"] = None
            _model_state["processor"] = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Load model with optional flash attention:
        # bf16 + device_map="auto" on GPU, fp32 on CPU.
        load_kwargs = {
            "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            "device_map": "auto" if torch.cuda.is_available() else None,
        }
        try:
            if torch.cuda.is_available():
                load_kwargs["attn_implementation"] = "flash_attention_2"
            model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
        except (ValueError, ImportError):
            # flash-attn unavailable/unsupported — retry with the default attention impl.
            load_kwargs.pop("attn_implementation", None)
            model = AutoModelForImageTextToText.from_pretrained(model_id, **load_kwargs)
        processor = AutoProcessor.from_pretrained(model_id)

        _model_state["model"] = model
        _model_state["processor"] = processor
        _model_state["model_id"] = model_id

        return True, f"Model loaded: {model_id}"
    except Exception as e:
        # Broad catch is deliberate: surface load failures as a status message, not a crash.
        logger.exception("Failed to load model")
        return False, f"Error loading model: {str(e)}"
95
+
96
+
97
def run_inference(image_path: str, prompt: str, model_id: str) -> Tuple[str, Optional[str], str]:
    """
    Run trace model inference on an image.

    Args:
        image_path: Filesystem path to the input image.
        prompt: Full text prompt (e.g. from build_prompt).
        model_id: Hugging Face model ID; loads/switches the model if needed.

    Returns:
        (prediction_text, overlay_image_path, trace_points_text).
        overlay_image_path is None on error or when fewer than 2 points were extracted.
    """
    success, msg = load_model(model_id)
    if not success:
        return msg, None, ""

    model = _model_state["model"]
    processor = _model_state["processor"]

    if image_path is None or not os.path.exists(image_path):
        return "Please provide a valid image.", None, ""

    try:
        # Ensure file:// format for qwen_vl_utils
        if not image_path.startswith("file://") and not image_path.startswith("http"):
            image_uri = f"file://{os.path.abspath(image_path)}"
        else:
            image_uri = image_path

        # Single-turn chat message: one image plus the text prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image_uri},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Apply chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process vision info
        if process_vision_info is not None:
            process_kwargs = {"return_video_kwargs": True, "return_video_metadata": True}
            # Pass the image patch size through when the processor exposes one.
            if hasattr(processor, "image_processor") and hasattr(
                processor.image_processor, "patch_size"
            ):
                process_kwargs["image_patch_size"] = processor.image_processor.patch_size

            image_inputs, video_inputs, video_kwargs = process_vision_info(
                messages, **process_kwargs
            )
        else:
            # Fallback: load image directly and pass to processor
            pil_image = Image.open(image_path).convert("RGB")
            image_inputs = [pil_image]
            video_inputs = None
            video_kwargs = {}

        # Prepare inputs
        processor_kwargs = {
            "text": [text],
            "images": image_inputs,
            "padding": True,
            "return_tensors": "pt",
            "do_resize": False,
        }
        # Videos may come back as (frames, metadata) tuples — unzip into parallel lists.
        if video_inputs is not None and len(video_inputs) > 0:
            if isinstance(video_inputs[0], tuple):
                videos, video_metadatas = zip(*video_inputs)
                processor_kwargs["videos"] = list(videos)
                processor_kwargs["video_metadata"] = list(video_metadatas)
            else:
                processor_kwargs["videos"] = video_inputs
            if video_kwargs:
                processor_kwargs.update(video_kwargs)

        inputs = processor(**processor_kwargs)
        # Move only tensor-like entries to the model's device; non-tensor entries are dropped.
        inputs = {k: v.to(model.device) for k, v in inputs.items() if hasattr(v, "to")}

        # Generate (greedy decoding: do_sample=False)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=False,
            )

        # Decode output, trimming the prompt tokens from each generated sequence.
        input_ids = inputs["input_ids"]
        generated_ids_trimmed = [
            out[len(inp) :] for inp, out in zip(input_ids, generated_ids)
        ]
        prediction = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        # Extract trajectory and visualize
        trajectories = extract_trajectory_from_text(prediction)
        trace_points_text = format_trace_points(trajectories)

        overlay_path = None
        if trajectories and len(trajectories) >= 2:
            # delete=False: the temp file must outlive this scope so the UI can serve it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as f:
                overlay_path = f.name
            # Try normalized first (common for VLMs)
            img_arr = visualize_trajectory_on_image(
                trajectory=trajectories,
                image_path=image_path,
                output_path=overlay_path,
                normalized=True,
            )
            if img_arr is None:
                # Fallback: pixel coordinates
                visualize_trajectory_on_image(
                    trajectory=trajectories,
                    image_path=image_path,
                    output_path=overlay_path,
                    normalized=False,
                )

        return prediction, overlay_path, trace_points_text

    except Exception as e:
        # Broad catch is deliberate: any inference failure is surfaced as text in the UI.
        logger.exception("Inference failed")
        return f"Error: {str(e)}", None, ""
225
+
226
+
227
def format_trace_points(trajectories) -> str:
    """Render trajectory points (List[List[float]]) as a markdown bullet list for display."""
    if not trajectories:
        return "No trajectory points extracted."

    lines = ["## Predicted Trace Points\n"]
    for idx, point in enumerate(trajectories, start=1):
        is_pair = isinstance(point, (list, tuple)) and len(point) >= 2
        if is_pair:
            lines.append(f"- Point {idx}: `[{point[0]:.4f}, {point[1]:.4f}]`")
        else:
            # Malformed entry: show it verbatim rather than failing.
            lines.append(f"- Point {idx}: `{point}`")
    return "\n".join(lines)
241
+
242
+
243
# --- Gradio UI ---
# Older Gradio releases don't accept theme= in Blocks(); fall back to the bare constructor.
try:
    demo = gr.Blocks(title="Trace Model Visualizer", theme=gr.themes.Soft())
except TypeError:
    demo = gr.Blocks(title="Trace Model Visualizer")

with demo:
    gr.Markdown(
        """
        # Trace Model Visualizer

        Upload an image to predict the trajectory/trace using [mihirgrao/trace-model](https://huggingface.co/mihirgrao/trace-model).

        The model predicts coordinate points; they are overlaid on the image (green → red gradient) and listed below.
        """
    )

    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=1):
            image_input = gr.Image(
                label="Upload Image",
                type="filepath",
                height=400,
            )
            instruction_input = gr.Textbox(
                label="Task / Language instruction",
                placeholder="e.g. Pick up the red block and place it on the table",
                value="",
                lines=2,
                info="Describe the task. The model will predict the trace for this instruction.",
            )
            model_id_input = gr.Textbox(
                label="Model ID",
                value=DEFAULT_MODEL_ID,
                info="Hugging Face model ID",
            )
            load_model_btn = gr.Button("Load Model", variant="secondary")
            run_btn = gr.Button("Run Inference", variant="primary")

        # Right column: overlay image, raw prediction text, extracted points.
        with gr.Column(scale=1):
            overlay_output = gr.Image(
                label="Image with Trace Overlay",
                height=400,
            )
            prediction_output = gr.Textbox(
                label="Model Prediction (raw)",
                lines=6,
            )
            trace_points_output = gr.Markdown(
                label="Extracted Trace Points",
            )

    status_md = gr.Markdown("Click 'Load Model' to load the trace model, then 'Run Inference' on an image.")

    def on_load_model(model_id: str):
        # Thin wrapper: load the model and report the result as a status line.
        _, msg = load_model(model_id)
        return f"**Status:** {msg}"

    def on_run_inference(image_path, instruction, model_id):
        # Guard against missing image, then delegate to run_inference.
        if image_path is None:
            return "Please upload an image first.", None, "", "**Status:** Please upload an image."
        prompt = build_prompt(instruction)
        pred, overlay_path, trace_text = run_inference(image_path, prompt, model_id)
        # No overlay means run_inference returned an error/notice in `pred`.
        status = "**Status:** Inference complete." if overlay_path else f"**Status:** {pred}"
        return pred, overlay_path, trace_text, status

    load_model_btn.click(
        fn=on_load_model,
        inputs=[model_id_input],
        outputs=[status_md],
    )

    run_btn.click(
        fn=on_run_inference,
        inputs=[image_input, instruction_input, model_id_input],
        outputs=[
            prediction_output,
            overlay_output,
            trace_points_output,
            status_md,
        ],
        api_name="run_inference",
    )
325
+ )
326
+
327
+
328
def main():
    """Launch the Gradio app on all interfaces at port 7860 (no public share link)."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)


if __name__ == "__main__":
    main()
predict_trace.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CLI script to predict trace on an image using the trace model.
4
+
5
+ Reuses load_model and run_inference from app.
6
+ """
7
+
8
+ import argparse
9
+ import os
10
+ import shutil
11
+ import sys
12
+
13
+ from app import DEFAULT_MODEL_ID, build_prompt, load_model, run_inference
14
+
15
+
16
def main():
    """CLI entry point: parse arguments, load the model, run inference, save the overlay.

    Exits non-zero on missing image, model-load failure, or inference error.
    """
    parser = argparse.ArgumentParser(
        description="Predict trace/trajectory on an image using mihirgrao/trace-model"
    )
    parser.add_argument("image", type=str, help="Path to input image")
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Path to save overlay image (default: <image>_trace.png)",
    )
    parser.add_argument(
        "-m",
        "--model-id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help=f"Model ID (default: {DEFAULT_MODEL_ID})",
    )
    parser.add_argument(
        "-i",
        "--instruction",
        type=str,
        default="",
        help="Task / language instruction (e.g. 'Pick up the red block')",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default=None,
        help="Full prompt override (if set, ignores --instruction)",
    )
    args = parser.parse_args()

    if not os.path.exists(args.image):
        print(f"Error: Image not found: {args.image}", file=sys.stderr)
        sys.exit(1)

    # Load model
    success, msg = load_model(args.model_id)
    if not success:
        print(f"Error: {msg}", file=sys.stderr)
        sys.exit(1)
    print(f"✓ {msg}")

    # Build prompt from instruction (a full --prompt overrides --instruction entirely)
    prompt = args.prompt if args.prompt is not None else build_prompt(args.instruction)

    # Run inference
    prediction, overlay_path, trace_text = run_inference(
        args.image, prompt, args.model_id
    )

    # Handle errors — run_inference signals failure via its text prefix, not exceptions.
    if prediction.startswith("Error:") or prediction.startswith("Please "):
        print(f"Error: {prediction}", file=sys.stderr)
        sys.exit(1)

    # No overlay: the model ran but no trajectory points could be extracted.
    if overlay_path is None:
        print("\nModel prediction (raw):")
        print(prediction)
        print("\n" + trace_text)
        print("\nNo trajectory points were extracted from the prediction.")
        sys.exit(0)

    # Save overlay to desired path if specified
    output_path = args.output
    if output_path is None:
        base, ext = os.path.splitext(args.image)
        output_path = f"{base}_trace{ext}"

    shutil.copy(overlay_path, output_path)
    os.unlink(overlay_path)  # Remove temp file
    print(f"\n✓ Overlay saved to: {output_path}")

    print("\nModel prediction (raw):")
    print(prediction)

    print("\n" + trace_text)


if __name__ == "__main__":
    main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ transformers>=4.45.0
4
+ accelerate>=0.25.0
5
+ Pillow>=9.0.0
6
+ numpy>=1.20.0
7
+ qwen-vl-utils>=0.0.8
trajectory_viz.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Trajectory Visualization Utilities for Trace Model
3
+
4
+ Extracts trajectory coordinates from model output text and overlays them on images.
5
+ Supports both pixel coordinates and normalized (0-1) coordinates.
6
+ """
7
+
8
+ import os
9
+ import re
10
+ from typing import List, Tuple, Optional, Union
11
+
12
+ import numpy as np
13
+ from PIL import Image, ImageDraw
14
+
15
+
16
def extract_trajectory_from_text(text: str) -> List[List[float]]:
    """
    Extract trajectory coordinates from model output text.

    Handles both pixel coordinates [[100, 200], [150, 250]] and
    normalized coordinates [[0.5, 0.3], [0.7, 0.4]].

    Args:
        text: The text output from the model containing trajectory information

    Returns:
        List of [x, y] coordinate pairs as floats (empty if no pairs are found)
    """
    # Match "[x, y]" pairs; each component is an optionally-negative int or float.
    coord_pattern = r"\[\s*(-?\d+(?:\.\d+)?)\s*,\s*(-?\d+(?:\.\d+)?)\s*\]"
    # The pattern only matches valid numeric literals, so float() cannot fail here
    # (the previous try/except around the conversion was dead code).
    return [[float(x), float(y)] for x, y in re.findall(coord_pattern, text)]
46
+
47
+
48
+ def _to_pixel_coords(
49
+ trajectory: List[List[float]],
50
+ img_width: int,
51
+ img_height: int,
52
+ normalized: bool = True,
53
+ ) -> List[List[int]]:
54
+ """Convert trajectory to pixel coordinates."""
55
+ pixel_traj = []
56
+ for x, y in trajectory:
57
+ if normalized:
58
+ x = int(x * img_width)
59
+ y = int(y * img_height)
60
+ else:
61
+ x, y = int(x), int(y)
62
+ pixel_traj.append([x, y])
63
+ return pixel_traj
64
+
65
+
66
def visualize_trajectory_on_image(
    trajectory: List[List[float]],
    image_path: Optional[str] = None,
    output_path: Optional[str] = None,
    pil_image: Optional[Image.Image] = None,
    normalized: bool = True,
    start_color: Tuple[int, int, int] = (0, 255, 0),
    end_color: Tuple[int, int, int] = (255, 0, 0),
    line_thickness: int = 4,
) -> Optional[np.ndarray]:
    """
    Overlay trajectory on an image with gradient coloring (green start -> red end).

    Args:
        trajectory: List of [x, y] coordinate pairs (pixel or normalized)
        image_path: Path to input image (used if pil_image is None)
        output_path: Where to save the output image
        pil_image: PIL Image to draw on (overrides image_path)
        normalized: If True, coordinates are 0-1 and will be scaled to image size
        start_color: RGB for trajectory start
        end_color: RGB for trajectory end
        line_thickness: Line width in pixels

    Returns:
        numpy array of the output image, or None if the trajectory has fewer
        than 2 points or no usable image source was provided
    """
    # Need at least two points to draw a segment.
    if not trajectory or len(trajectory) < 2:
        return None

    # Resolve the image source: explicit PIL image wins over a path.
    if pil_image is not None:
        img = pil_image.convert("RGB").copy()
    elif image_path and os.path.exists(image_path):
        img = Image.open(image_path).convert("RGB").copy()
    else:
        return None

    w, h = img.size
    pixel_traj = _to_pixel_coords(trajectory, w, h, normalized=normalized)

    # Clamp to image bounds
    pixel_traj = [
        [max(0, min(w - 1, x)), max(0, min(h - 1, y))]
        for x, y in pixel_traj
    ]

    draw = ImageDraw.Draw(img)

    # Draw gradient line segments: color interpolates linearly from
    # start_color (progress 0) to end_color (progress 1) across segments.
    num_segments = len(pixel_traj) - 1
    for i in range(num_segments):
        # max(1, ...) avoids division by zero when there is a single segment.
        progress = i / max(1, num_segments - 1)
        r = int(start_color[0] * (1 - progress) + end_color[0] * progress)
        g = int(start_color[1] * (1 - progress) + end_color[1] * progress)
        b = int(start_color[2] * (1 - progress) + end_color[2] * progress)
        segment_color = (r, g, b)
        start_pt = tuple(pixel_traj[i])
        end_pt = tuple(pixel_traj[i + 1])
        draw.line([start_pt, end_pt], fill=segment_color, width=line_thickness)

    # Draw start marker: a filled circle at the first point, white outline.
    if pixel_traj:
        sx, sy = pixel_traj[0]
        r = max(3, line_thickness)
        bbox = [sx - r, sy - r, sx + r, sy + r]
        draw.ellipse(bbox, fill=start_color, outline=(255, 255, 255), width=2)

    if output_path:
        img.save(output_path)

    return np.array(img)