JK-Ying-Long commited on
Commit
2c8aff3
·
verified ·
1 Parent(s): 4b32129

initial commit

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. 8.gif +3 -0
  3. DroneStalker.py +69 -0
  4. Figure_1.png +0 -0
  5. README.md +123 -3
  6. dronestalker-1.1.pth +3 -0
  7. requirements.txt +1 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 8.gif filter=lfs diff=lfs merge=lfs -text
8.gif ADDED

Git LFS Details

  • SHA256: 001fc22656956841b530381611c449f710d0666f0ec00ad4f66043262740b67b
  • Pointer size: 132 Bytes
  • Size of remote file: 2.47 MB
DroneStalker.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class DroneStalker(nn.Module):
    """Shared base for drone-trajectory models.

    Holds the frame-timing / image-size constants and converts a sequence
    of pixel-space bounding boxes into per-frame kinematic feature vectors.
    """

    INTERVAL = 0.033333  # Seconds between consecutive frames (~30 FPS)
    IMAGE_WIDTH = 1280   # Pixel width used to normalize x coordinates
    IMAGE_HEIGHT = 720   # Pixel height used to normalize y coordinates

    def __init__(self, Np: int, Nf: int):
        """Np: number of past frames consumed; Nf: number of future frames predicted."""
        super().__init__()
        self.Np = Np
        self.Nf = Nf

    def _extract_features(self, sample):
        """Map a sequence of boxes to a list of per-frame kinematic features.

        The first frame has no predecessor, so it is paired with itself,
        which yields zero velocity for that frame.
        """
        return [
            self._get_kinematics(sample[idx - 1] if idx else current, current)
            for idx, current in enumerate(sample)
        ]

    def _get_kinematics(self, past_box, box):
        """Return [x_center, y_center, x_velocity, y_velocity].

        Centers are normalized to [0, 1] by the image dimensions;
        velocities are normalized units per second (delta / INTERVAL).
        """
        prev_x1 = past_box[0] / self.IMAGE_WIDTH
        prev_y1 = past_box[1] / self.IMAGE_HEIGHT
        prev_x2 = past_box[2] / self.IMAGE_WIDTH
        prev_y2 = past_box[3] / self.IMAGE_HEIGHT
        cur_x1 = box[0] / self.IMAGE_WIDTH
        cur_y1 = box[1] / self.IMAGE_HEIGHT
        cur_x2 = box[2] / self.IMAGE_WIDTH
        cur_y2 = box[3] / self.IMAGE_HEIGHT
        cur_cx = (cur_x1 + cur_x2) / 2
        cur_cy = (cur_y1 + cur_y2) / 2
        prev_cx = (prev_x1 + prev_x2) / 2
        prev_cy = (prev_y1 + prev_y2) / 2
        vx = (cur_cx - prev_cx) / self.INTERVAL
        vy = (cur_cy - prev_cy) / self.INTERVAL
        return [cur_cx, cur_cy, vx, vy]
34
+
35
class DroneStalkerBase(DroneStalker):
    """Extends the base kinematics with box-geometry features.

    Per-frame feature vector:
    [x_center, y_center, x_velocity, y_velocity, width, height, x1, y1]
    (all values normalized by the image dimensions).
    """

    def __init__(self, Np: int, Nf: int):
        super().__init__(Np, Nf)

    def _get_kinematics(self, past_box, box):
        """Append normalized width, height and top-left corner to the base features."""
        base = super()._get_kinematics(past_box, box)
        left = box[0] / self.IMAGE_WIDTH
        top = box[1] / self.IMAGE_HEIGHT
        right = box[2] / self.IMAGE_WIDTH
        bottom = box[3] / self.IMAGE_HEIGHT
        return base + [right - left, bottom - top, left, top]
45
+
46
class Model(DroneStalkerBase):
    """GRU trajectory predictor.

    Consumes batches of Np past bounding boxes in pixel coordinates
    ([x1, y1, x2, y2] per frame) and predicts Nf future boxes.
    """

    def __init__(self, Np: int, Nf: int, hidden_dim: int = 128, num_layers: int = 2, dropout: float = 0.1):
        super().__init__(Np, Nf)
        # Input layer: projects the 8 kinematic features per frame to 16 dims.
        self.input = nn.Linear(8, 16)
        self.leaky_relu = nn.LeakyReLU()
        # Recurrent core; batch_first so tensors are (batch, seq, feature).
        self.hidden = nn.GRU(input_size=16, hidden_size=hidden_dim, num_layers=num_layers, dropout=dropout, batch_first=True)
        # Output head maps the final hidden state to Nf boxes of 4 values each.
        self.output = nn.Linear(hidden_dim, Nf * 4)

    def forward(self, batch):
        """Predict future bounding boxes from past ones.

        Args:
            batch: (batch_size, Np, 4) past boxes in pixel coordinates
                [x1, y1, x2, y2]; a tensor or a nested list.

        Returns:
            (batch_size, Nf, 4) float32 tensor of predicted boxes.
        """
        # len() also works for plain nested lists, not just tensors.
        batch_size = batch.shape[0] if hasattr(batch, "shape") else len(batch)

        # Extract per-frame kinematic features for each sample. Elements may
        # be 0-dim tensors (tensor input) or Python numbers (list input);
        # coerce to float so torch.tensor() doesn't copy-construct from
        # tensors (slow, and warned against by PyTorch).
        features = [
            [[float(v) for v in frame] for frame in self._extract_features(sample)]
            for sample in batch
        ]
        # Build the input on the module's device so GPU inference works;
        # the original always built it on CPU.
        device = next(self.parameters()).device
        x = torch.tensor(features, dtype=torch.float32, device=device)

        # Forward pass
        out = self.leaky_relu(self.input(x))
        out, _ = self.hidden(out)
        # Only the last time step's hidden state feeds the prediction head.
        out = self.output(out[:, -1, :])
        return out.view(batch_size, self.Nf, 4)
Figure_1.png ADDED
README.md CHANGED
@@ -1,3 +1,123 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - trajectory-prediction
4
+ - lstm
5
+ - drone-tracking
6
+ - computer-vision
7
+ license: apache-2.0
8
+ datasets:
9
+ - Ecoaetix/uFRED-predict-0.4
10
+ ---
11
+
12
+ # Drone Stalker 1
13
+
14
+ ![Demo GIF](8.gif)
15
+
16
+ GRU model for predicting drone trajectories based on bounding box sequences from video footage.
17
+
18
+ ## Model Description
19
+
20
+ This model predicts future drone positions given past trajectory data. It processes sequences of bounding boxes and outputs predicted future positions, significantly outperforming baseline models on the FRED dataset.
21
+
22
+ Drone Stalker 1 is an extremely lightweight model with just 2,592 parameters. Despite this, its performance is on par with other models of up to 300k parameters.
23
+
24
+ ## Architecture
25
+
26
+ - **Model Type**: GRU (Gated Recurrent Unit)
27
+ - **Input Features**: [x_center, y_center, x_velocity, y_velocity, width, height, x1, y1]
28
+ - **Total Parameters**: 2,592
29
+ - **Input Sequence Length**: 12 frames (Np=12)
30
+ - **Output Sequence Length**: 12 frames (Nf=12)
31
+ - **Frame Interval**: 33.3ms (30 FPS)
32
+ - **Image Resolution**: 1280x720
33
+
34
+ ### Output
35
+
36
+ Predicted future bounding boxes (normalized [0, 1])
37
+
38
+ ## Training Details
39
+
40
+ - **Dataset**: uFRED-predict-0.4
41
+ - **Epochs**: 25
42
+ - **Learning Rate**: 1e-3
43
+ - **Optimizer**: Adam
44
+ - **Loss Function**: Smooth L1 Loss
45
+
46
+ ## Performance
47
+
48
+ Evaluation metrics on test set:
49
+
50
+ - **Average Displacement Error (ADE)**: 23.91px
51
+ - **Final Displacement Error (FDE)**: 43.83px
52
+ - **Mean Intersection over Union (mIoU)**: 0.5135
53
+
54
+ ![Performance Comparison Chart](Figure_1.png)
55
+
56
+ ## Usage
57
+
58
+ ```python
59
+ import torch
60
+
61
+ # Download and load the state dict
62
+ model = torch.hub.load_state_dict_from_url(
63
+ 'https://huggingface.co/Ecoaetix/DroneStalker/resolve/main/dronestalker-1.1.pth'
64
+ )
65
+
66
+ # Or download and load manually
67
+ from huggingface_hub import hf_hub_download
68
+
69
+ model_path = hf_hub_download(
70
+ repo_id="Ecoaetix/DroneStalker",
71
+ filename="dronestalker-1.1.pth"
72
+ )
73
+
74
+ # You'll need the Model class (included as DroneStalker.py in this repo)
75
+ from DroneStalker import Model
76
+
77
+ model = Model(Np=12, Nf=12, hidden_dim=16, num_layers=1, dropout=0)
78
+ model.load_state_dict(torch.load(model_path))
79
+ model.eval()
80
+
81
+ # Inference
82
+ with torch.no_grad():
83
+ # Input: [batch_size, 12, 4] - 12 past bounding boxes [x1, y1, x2, y2]
84
+ predictions = model(past_bboxes)
85
+ # Output: [batch_size, 12, 4] - 12 future bounding boxes (min-max normalized)
86
+ ```
87
+
88
+ ## Input Format
89
+
90
+ The model expects input bounding boxes in pixel coordinates:
91
+ - Shape: `[batch_size, 12, 4]`
92
+ - Format: `[x1, y1, x2, y2]` where (x1,y1) is top-left, (x2,y2) is bottom-right
93
+ - Image dimensions: 1280x720 pixels
94
+
95
+ ## Output Format
96
+
97
+ The model outputs normalized predictions:
98
+ - Shape: `[batch_size, 12, 4]`
99
+ - Format: `[x1_norm, y1_norm, x2_norm, y2_norm]` where values are in range [0, 1]
100
+ - Multiply x-coordinates by 1280 and y-coordinates by 720 to get pixel values
101
+
102
+ ## Limitations
103
+
104
+ - Trained specifically on drone footage at 1280x720 resolution
105
+ - Assumes consistent frame rate of 30 FPS
106
+ - Best performance on stationary, ground-based tracking scenarios similar to training data
107
+ - Single object tracking only
108
+
109
+ ## Citation
110
+
111
+ ```bibtex
112
+ @misc{DroneStalker-LSTM-0.3,
113
+ author = {Jacob Kenney},
114
+ title = {DroneStalker-LSTM-0.3},
115
+ year = {2025},
116
+ publisher = {HuggingFace},
117
+ howpublished = {\url{https://huggingface.co/Ecoaetix/DroneStalker}}
118
+ }
119
+ ```
120
+
121
+ ## License
122
+
123
+ Apache 2.0
dronestalker-1.1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f514dea60dc7c55b0ee8b7eecc1833329d71338091f1222bf04dec0b3999e55f
3
+ size 13676
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ torch>=2.0.0