Spaces:
Runtime error
Runtime error
Upload 9 files
Browse files- Fight_detec_func.py +103 -0
- README.md +187 -12
- app.py +35 -0
- frame_slicer.py +58 -0
- full_project.py +22 -0
- model_summary.py +10 -0
- objec_detect_yolo.py +121 -0
- requirements.txt +6 -0
- trainig.py +248 -0
Fight_detec_func.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
from frame_slicer import extract_video_frames
|
| 3 |
+
import cv2
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
# Configuration
# Paths are resolved relative to this file so the module works no matter
# which directory the process was started from. `os` is already imported at
# the top of the file, so it is not imported a second time here.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH = os.path.join(_BASE_DIR, "trainnig_output", "final_model_2.h5")
N_FRAMES = 30  # number of frames sampled from each video
IMG_SIZE = (96, 96)  # size each sampled frame is resized to
RESULT_PATH = os.path.join(_BASE_DIR, "results")  # Will be created if doesn't exist
|
| 14 |
+
|
| 15 |
+
def fight_detec(video_path: str, debug: bool = True):
    """Detect whether a video contains a fight.

    Loads the Keras model stored at ``MODEL_PATH``, samples ``N_FRAMES``
    frames of size ``IMG_SIZE`` from the video with ``extract_video_frames``
    and thresholds the model's output score.

    Args:
        video_path: Path of the video file to classify.
        debug: When true, print diagnostics and save a sample frame plus a
            frame-grid visualization under ``RESULT_PATH``.

    Returns:
        A ``(message, score)`` tuple. On success ``message`` looks like
        ``"FIGHT (85.3% confidence)"`` or ``"NORMAL (...)"`` and ``score`` is
        the raw model output as a Python ``float``. On any failure
        ``message`` describes the error and ``score`` is ``None``.
    """
    # Scores at or above this value are reported as "FIGHT".
    fight_threshold = 0.61

    class FightDetector:
        """Wraps model loading, frame extraction and prediction for one call."""

        def __init__(self):
            self.model = self._load_model()

        def _load_model(self):
            """Return the loaded Keras model, or None if loading fails."""
            try:
                model = tf.keras.models.load_model(MODEL_PATH, compile=False)
                if debug:
                    print("\nModel loaded successfully. Input shape:", model.input_shape)
                return model
            except Exception as e:
                # Best effort: the caller turns a missing model into an
                # error message instead of crashing.
                print(f"Model loading failed: {e}")
                return None

        def _extract_frames(self, video_path):
            """Return the sampled frames, or None if extraction failed."""
            frames = extract_video_frames(video_path, N_FRAMES, IMG_SIZE)
            if frames is None:
                return None

            if debug:
                blank_frames = np.all(frames == 0, axis=(1, 2, 3)).sum()
                if blank_frames > 0:
                    print(f"Warning: {blank_frames} blank frames detected")
                # Frames are scaled to [0, 1]; rescale to 8-bit for saving.
                sample_frame = (frames[0] * 255).astype(np.uint8)
                os.makedirs(RESULT_PATH, exist_ok=True)
                cv2.imwrite(os.path.join(RESULT_PATH, 'debug_frame.jpg'),
                            cv2.cvtColor(sample_frame, cv2.COLOR_RGB2BGR))

            return frames

        def predict(self, video_path):
            """Classify the video and return ``(message, score)``."""
            if not os.path.exists(video_path):
                return "Error: Video not found", None

            try:
                frames = self._extract_frames(video_path)
                if frames is None:
                    return "Error: Frame extraction failed", None

                if frames.shape[0] != N_FRAMES:
                    return f"Error: Expected {N_FRAMES} frames, got {frames.shape[0]}", None

                if np.all(frames == 0):
                    return "Error: All frames are blank", None

                # Add a leading batch axis; convert the NumPy scalar to a
                # plain float so callers can serialize it (e.g. to JSON).
                prediction = float(
                    self.model.predict(frames[np.newaxis, ...], verbose=0)[0][0]
                )
                result = "FIGHT" if prediction >= fight_threshold else "NORMAL"
                # Heuristic display value: 50% at the threshold, growing with
                # the distance from it and capped at 100%.
                confidence = min(abs(prediction - fight_threshold) * 150 + 50, 100)

                if debug:
                    self._debug_visualization(frames, prediction, result, video_path)

                return f"{result} ({confidence:.1f}% confidence)", prediction

            except Exception as e:
                return f"Prediction error: {str(e)}", None

        def _debug_visualization(self, frames, score, result, video_path):
            """Print the decision and save a grid of the first ten frames."""
            print(f"\nPrediction Score: {score:.4f}")
            print(f"Decision: {result}")
            plt.figure(figsize=(15, 5))
            for i in range(min(10, len(frames))):
                plt.subplot(2, 5, i+1)
                plt.imshow(frames[i])
                plt.title(f"Frame {i}\nMean: {frames[i].mean():.2f}")
                plt.axis('off')
            plt.suptitle(f"Prediction: {result} (Score: {score:.4f})")
            plt.tight_layout()

            # Save the visualization. Create the directory here as well so
            # this method does not depend on another one having done it.
            os.makedirs(RESULT_PATH, exist_ok=True)
            base_name = os.path.splitext(os.path.basename(video_path))[0]
            save_path = os.path.join(RESULT_PATH, f"{base_name}_prediction_result.png")
            plt.savefig(save_path)
            plt.close()
            print(f"Visualization saved to: {save_path}")

    detector = FightDetector()
    if detector.model is None:
        return "Error: Model loading failed", None
    return detector.predict(video_path)
|
| 98 |
+
|
| 99 |
+
# # Entry point
|
| 100 |
+
# path0 = input("Enter the local path to the video file to detect fight: ")
|
| 101 |
+
# path = path0.strip('"') # Remove extra quotes if copied from Windows
|
| 102 |
+
# print(f"[INFO] Loading video: {path}")
|
| 103 |
+
# fight_detec(path)
|
README.md
CHANGED
|
@@ -1,12 +1,187 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Video Analysis Project: Fight and Object Detection
|
| 2 |
+
|
| 3 |
+
## 1. Overview
|
| 4 |
+
|
| 5 |
+
This project analyzes video files to perform two main tasks:
|
| 6 |
+
* **Fight Detection:** Classifies video segments as containing a "FIGHT" or being "NORMAL" using a custom-trained 3D Convolutional Neural Network (CNN).
|
| 7 |
+
* **Object Detection:** Identifies and tracks specific objects within the video using a pre-trained YOLOv8 model.
|
| 8 |
+
|
| 9 |
+
The system processes an input video and outputs the fight classification result along with an annotated version of the video highlighting detected objects.
|
| 10 |
+
|
| 11 |
+
## 2. Features
|
| 12 |
+
|
| 13 |
+
* Dual analysis: Combines action recognition (fight detection) and object detection.
|
| 14 |
+
* Custom-trained model for fight detection tailored to specific data.
|
| 15 |
+
* Utilizes state-of-the-art YOLOv8 for object detection.
|
| 16 |
+
* Generates an annotated output video showing detected objects and their tracks.
|
| 17 |
+
* Provides confidence scores for fight detection results.
|
| 18 |
+
* Includes scripts for both inference (`full_project.py`) and training (`trainig.py`) the fight detection model.
|
| 19 |
+
|
| 20 |
+
## 3. Project Structure
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
ComV/
|
| 24 |
+
├── [Project Directory]/ # e.g., AI_made
|
| 25 |
+
│ ├── full_project.py # Main script for running inference
|
| 26 |
+
│ ├── Fight_detec_func.py # Fight detection logic and model loading
|
| 27 |
+
│ ├── objec_detect_yolo.py # Object detection logic using YOLOv8
|
| 28 |
+
│ ├── frame_slicer.py # Utility for extracting frames for fight detection
|
| 29 |
+
│ ├── trainig.py # Script for training the fight detection model
|
| 30 |
+
│ ├── README.md # This documentation file
|
| 31 |
+
│ └── trainnig_output/ # Directory for training artifacts
|
| 32 |
+
│ ├── final_model_2.h5 # Trained fight detection model (relative path)
|
| 33 |
+
│ └── checkpoint/ # Checkpoints saved during training (relative path)
|
| 34 |
+
│ └── training_log.csv # Log file for training history (relative path)
|
| 35 |
+
│ └── yolo/ # (Assumed location)
|
| 36 |
+
│ └── yolo/
|
| 37 |
+
│ └── best.pt # Pre-trained YOLOv8 model weights (relative path)
|
| 38 |
+
├── train/
|
| 39 |
+
│ ├── Fighting/ # Directory containing fight video examples (relative path)
|
| 40 |
+
│ └── Normal/ # Directory containing normal video examples (relative path)
|
| 41 |
+
└── try/
|
| 42 |
+
├── result/ # Directory where output videos are saved (relative path)
|
| 43 |
+
└── ... (Input video files) # Location for input videos (example)
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
*(Note: Model paths and data directories might be hardcoded in the scripts. Consider making these configurable or using relative paths.)*
|
| 47 |
+
|
| 48 |
+
## 4. Setup and Installation
|
| 49 |
+
|
| 50 |
+
**Python Version:**
|
| 51 |
+
|
| 52 |
+
* This project was developed and tested using Python 3.10.
|
| 53 |
+
|
| 54 |
+
**Dependencies:**
|
| 55 |
+
|
| 56 |
+
Based on the code imports and `pip freeze` output, the following libraries and versions were used:
|
| 57 |
+
|
| 58 |
+
* `opencv-python==4.11.0.86` (cv2)
|
| 59 |
+
* `numpy==1.26.4`
|
| 60 |
+
* `tensorflow==2.19.0` (tf)
|
| 61 |
+
* `ultralytics==8.3.108` (for YOLOv8)
|
| 62 |
+
* `matplotlib==3.10.1` (for debug visualizations)
|
| 63 |
+
* `scikit-learn==1.6.1` (sklearn - used in `trainig.py`)
|
| 64 |
+
|
| 65 |
+
*(Note: Other versions might also work, but these are the ones confirmed in the development environment.)*
|
| 66 |
+
|
| 67 |
+
**Installation (using pip):**
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
pip install opencv-python numpy tensorflow ultralytics matplotlib scikit-learn
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
**Models:**
|
| 74 |
+
|
| 75 |
+
1. **Fight Detection Model:** Ensure the trained model (`final_model_2.h5`) is present in the `trainnig_output` subdirectory relative to the script location.
|
| 76 |
+
2. **YOLOv8 Model:** Ensure the YOLO model (`best.pt`) is present in the `yolo` subdirectory relative to the script location (`objec_detect_yolo.py` loads `yolo/best.pt`).
|
| 77 |
+
|
| 78 |
+
*(Note: Absolute paths might be hardcoded in the scripts and may need adjustment depending on the deployment environment.)*
|
| 79 |
+
|
| 80 |
+
## 5. Usage
|
| 81 |
+
|
| 82 |
+
To run the analysis on a video file:
|
| 83 |
+
|
| 84 |
+
1. Navigate to the `d:/K_REPO/ComV/AI_made/` directory in your terminal (or ensure Python's working directory is `d:/K_REPO`).
|
| 85 |
+
2. Run the main script:
|
| 86 |
+
```bash
|
| 87 |
+
python full_project.py
|
| 88 |
+
```
|
| 89 |
+
3. The script will prompt you to enter the path to the video file:
|
| 90 |
+
```
|
| 91 |
+
Enter the local path : <your_video_path.mp4>
|
| 92 |
+
```
|
| 93 |
+
*(Ensure you provide the full path, potentially removing extra quotes if copying from Windows Explorer.)*
|
| 94 |
+
|
| 95 |
+
**Output:**
|
| 96 |
+
|
| 97 |
+
* The console will print the fight detection result (e.g., "FIGHT (85.3% confidence)") and information about the object detection process.
|
| 98 |
+
* An annotated video file will be saved in the `results` directory, which is created under the current working directory if it does not exist. The filename will include the original video name and the unique detected object labels (e.g., `input_video_label1_label2_output.mp4`).
|
| 99 |
+
* If debug mode is enabled in `Fight_detec_func.py`, additional debug images might be saved in the result directory.
|
| 100 |
+
|
| 101 |
+
## 6. Module Descriptions
|
| 102 |
+
|
| 103 |
+
* **`full_project.py`:** Orchestrates the process by taking user input and calling the fight detection and object detection functions.
|
| 104 |
+
* **`Fight_detec_func.py`:**
|
| 105 |
+
* Contains the `fight_detec` function and `FightDetector` class.
|
| 106 |
+
* Loads the Keras model (`final_model_2.h5`).
|
| 107 |
+
* Uses `frame_slicer` to prepare input for the model.
|
| 108 |
+
* Performs prediction and calculates confidence.
|
| 109 |
+
* Handles debug visualizations.
|
| 110 |
+
* **`objec_detect_yolo.py`:**
|
| 111 |
+
* Contains the `detection` function.
|
| 112 |
+
* Loads the YOLOv8 model (`best.pt`).
|
| 113 |
+
* Iterates through video frames, performs object detection and tracking.
|
| 114 |
+
* Generates and saves the annotated output video.
|
| 115 |
+
* Returns detected object labels.
|
| 116 |
+
* **`frame_slicer.py`:**
|
| 117 |
+
* Contains the `extract_video_frames` utility function.
|
| 118 |
+
* Extracts a fixed number of frames, resizes, normalizes, and handles potential errors during extraction.
|
| 119 |
+
* **`trainig.py`:**
|
| 120 |
+
* Script for training the fight detection model.
|
| 121 |
+
* Includes `VideoDataGenerator` for loading/processing video data.
|
| 122 |
+
* Defines the 3D CNN model architecture.
|
| 123 |
+
* Handles data loading, splitting, training loops, checkpointing, and saving the final model.
|
| 124 |
+
|
| 125 |
+
## 7. Training Data
|
| 126 |
+
|
| 127 |
+
### Dataset Composition
|
| 128 |
+
| Category | Count | Percentage | Formats | Avg Duration |
|
| 129 |
+
|----------------|-------|------------|---------------|--------------|
|
| 130 |
+
| Fight Videos | 2,340 | 61.9% | .mp4, .mpeg | 5.2 sec |
|
| 131 |
+
| Normal Videos | 1,441 | 38.1% | .mp4, .mpeg | 6.1 sec |
|
| 132 |
+
| **Total** | **3,781** | **100%** | | |
|
| 133 |
+
|
| 134 |
+
### Technical Specifications
|
| 135 |
+
- **Resolution:** 64×64 pixels
|
| 136 |
+
- **Color Space:** RGB
|
| 137 |
+
- **Frame Rate:** 30 FPS (average)
|
| 138 |
+
- **Frame Sampling:** 50 frames per video
|
| 139 |
+
- **Input Shape:** (30, 96, 96, 3) - Model resizes input
|
| 140 |
+
|
| 141 |
+
### Data Sources
|
| 142 |
+
- Fighting videos: Collected from public surveillance datasets
|
| 143 |
+
- Normal videos: Sampled from public CCTV footage
|
| 144 |
+
- Manually verified and labeled by domain experts
|
| 145 |
+
|
| 146 |
+
### Preprocessing
|
| 147 |
+
1. Frame extraction at 50 frames/video
|
| 148 |
+
2. Resizing to 96×96 pixels
|
| 149 |
+
3. Normalization (pixel values [0,1])
|
| 150 |
+
4. Temporal sampling to 30 frames for model input
|
| 151 |
+
|
| 152 |
+
## 8. Models Used
|
| 153 |
+
|
| 154 |
+
* **Fight Detection:** A custom 3D CNN trained using `trainig.py`. Located at `D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5`. Input shape expects `(30, 96, 96, 3)` frames.
|
| 155 |
+
* **Object Detection:** YOLOv8 model. Weights located at `D:\K_REPO\ComV\yolo\yolo\best.pt`. This model is trained to detect the following classes: `['Fire', 'Gun', 'License_Plate', 'Smoke', 'knife']`.
|
| 156 |
+
|
| 157 |
+
## 8a. Fight Detection Model Performance
|
| 158 |
+
|
| 159 |
+
The following metrics represent the performance achieved during the training of the `final_model_2.h5`:
|
| 160 |
+
|
| 161 |
+
* **Best Training Accuracy:** 0.8583 (Epoch 7)
|
| 162 |
+
* **Best Validation Accuracy:** 0.9167 (Epoch 10)
|
| 163 |
+
* **Lowest Training Loss:** 0.3636 (Epoch 7)
|
| 164 |
+
* **Lowest Validation Loss:** 0.2805 (Epoch 8)
|
| 165 |
+
|
| 166 |
+
*(Note: These metrics are based on the training run that produced the saved model. Performance may vary slightly on different datasets or during retraining.)*
|
| 167 |
+
|
| 168 |
+
## 8. Configuration
|
| 169 |
+
|
| 170 |
+
Key parameters and paths are mostly hardcoded within the scripts:
|
| 171 |
+
|
| 172 |
+
* `Fight_detec_func.py`: `MODEL_PATH`, `N_FRAMES`, `IMG_SIZE`, `RESULT_PATH`.
|
| 173 |
+
* `objec_detect_yolo.py`: YOLO model path, output directory path (`output_dir`), confidence threshold (`conf=0.7`).
|
| 174 |
+
* `trainig.py`: `DATA_DIR`, `N_FRAMES`, `IMG_SIZE`, `EPOCHS`, `BATCH_SIZE`, `CHECKPOINT_DIR`, `OUTPUT_PATH`.
|
| 175 |
+
|
| 176 |
+
*Recommendation: Refactor these hardcoded values into a separate configuration file (e.g., YAML or JSON) or use command-line arguments for better flexibility.*
|
| 177 |
+
|
| 178 |
+
## 9. Training the Fight Detection Model
|
| 179 |
+
|
| 180 |
+
To retrain or train the fight detection model:
|
| 181 |
+
|
| 182 |
+
1. **Prepare Data:** Place training videos into `D:\K_REPO\ComV\train\Fighting` and `D:\K_REPO\ComV\train\Normal` subdirectories.
|
| 183 |
+
2. **Run Training Script:** Execute `trainig.py`:
|
| 184 |
+
```bash
|
| 185 |
+
python trainig.py
|
| 186 |
+
```
|
| 187 |
+
3. The script will load data, build the model (or resume from a checkpoint if `RESUME_TRAINING=1` and a checkpoint exists), train it, and save the final model to `D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5`. Checkpoints and logs are saved in the `trainnig_output` directory.
|
app.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import tempfile
|
| 4 |
+
from Fight_detec_func import fight_detec
|
| 5 |
+
from objec_detect_yolo import detection
|
| 6 |
+
|
| 7 |
+
def analyze_video(video_file):
    """Run fight detection and object detection on an uploaded video.

    Args:
        video_file: The value supplied by the ``gr.Video`` input. Gradio
            passes the path of the uploaded file as a string; objects that
            expose the path through a ``name`` attribute are accepted too.

    Returns:
        A JSON-serializable dict. ``"Fight Detection"`` holds the message
        returned by ``fight_detec``; ``"YOLO Object Detection"`` holds a
        two-item list ``[sorted_labels, annotated_video_path]``. When no
        video was supplied, an error message is returned instead.
    """
    if isinstance(video_file, (str, os.PathLike)):
        video_path = os.fspath(video_file)
    else:
        video_path = getattr(video_file, "name", None)

    if not video_path:
        return {
            "Fight Detection": "Error: No video provided",
            "YOLO Object Detection": None,
        }

    # The upload already exists on disk, so both detectors read it in place;
    # no extra copy has to be written and cleaned up afterwards.
    fight_message, _score = fight_detec(video_path, debug=False)
    labels, annotated_path = detection(video_path)

    return {
        "Fight Detection": fight_message,
        # `detection` returns the labels as a set, which cannot be
        # serialized to JSON; emit a sorted list instead.
        "YOLO Object Detection": [sorted(labels), annotated_path],
    }
|
| 26 |
+
|
| 27 |
+
# Build the web UI: one video upload in, one JSON panel out.
video_input = gr.Video(label="Upload Video")
results_output = gr.JSON(label="Detection Results")

iface = gr.Interface(
    fn=analyze_video,
    inputs=video_input,
    outputs=results_output,
    title="Fight and Object Detection System",
    description="Upload a video to detect fights and objects using our AI models",
)

# Start the Gradio server.
iface.launch()
|
frame_slicer.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
def extract_video_frames(video_path, n_frames=30, frame_size=(96, 96)):
    """Sample ``n_frames`` evenly spaced RGB frames from a video.

    Each frame that is read successfully is resized with
    ``cv2.resize(frame, frame_size)``, converted from BGR to RGB and scaled
    to ``float32`` values in [0, 1]. A frame that cannot be read is replaced
    by a copy of the most recent good frame, or by a light-gray placeholder
    if no good frame has been seen yet.

    Args:
        video_path: Path of the video file to read.
        n_frames: Number of frames to return.
        frame_size: Target size handed to ``cv2.resize``.

    Returns:
        A NumPy array holding ``n_frames`` frames, or ``None`` when the video
        cannot be opened, reports fewer than one frame or an fps below 1, or
        when not a single frame could be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video {video_path}")
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Basic video validation
        if total_frames < 1 or fps < 1:
            print(f"Error: Invalid video (frames:{total_frames}, fps:{fps})")
            return None

        frames = []
        last_good_frame = None

        for i in range(n_frames):
            # Spread the read positions evenly across the video.
            pos = min(int(i * (total_frames / n_frames)), total_frames - 1)
            cap.set(cv2.CAP_PROP_POS_FRAMES, pos)

            ret, frame = cap.read()

            if ret and frame is not None:
                # Process valid frame
                frame = cv2.resize(frame, frame_size)
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = frame.astype(np.float32) / 255.0
                last_good_frame = frame
            elif last_good_frame is not None:
                # Read failed: repeat the last frame that decoded correctly.
                frame = last_good_frame.copy()
            else:
                # Nothing decoded yet: use a light-gray placeholder frame.
                frame = np.full((*frame_size[::-1], 3), 0.8, dtype=np.float32)

            frames.append(frame)
    finally:
        # Release the capture even if reading or resizing raised.
        cap.release()

    if last_good_frame is None:
        # Every read failed, so the result would consist of placeholders
        # only; report failure rather than returning fabricated frames.
        print(f"Error: No frames could be read from {video_path}")
        return None

    return np.array(frames)
|
full_project.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from ultralytics import YOLO
|
| 5 |
+
import time
|
| 6 |
+
import tensorflow as tf
|
| 7 |
+
from frame_slicer import extract_video_frames
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
from Fight_detec_func import fight_detec
|
| 11 |
+
from objec_detect_yolo import detection
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Entry point
path0 = input("Enter the local path : ")
# Remove surrounding whitespace and quotes (present when the path is copied
# from Windows Explorer).
path = path0.strip().strip('"\'')
print(f"[INFO] Loading video: {path}")

# fight_detec returns (message, score); show the message to the user.
fight_result, _fight_score = fight_detec(path)
print(f"[INFO] Fight detection result: {fight_result}")
detection(path)
|
| 22 |
+
|
model_summary.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Prints the architecture of the trained fight-detection model and the
# devices TensorFlow reports as available.
import os

from tensorflow.keras.models import load_model
from tensorflow.python.client import device_lib

# Resolve the model relative to this file (the same layout
# Fight_detec_func.py uses) instead of a machine-specific absolute path.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "trainnig_output",
    "final_model_2.h5",
)

model = load_model(MODEL_PATH)
model.summary()

print("[INFO] Devices available:")
print(device_lib.list_local_devices())
|
objec_detect_yolo.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import os
|
| 4 |
+
from ultralytics import YOLO
|
| 5 |
+
import time
|
| 6 |
+
from typing import Tuple, Set
|
| 7 |
+
|
| 8 |
+
def detection(path: str) -> Tuple[Set[str], str]:
    """
    Detects and tracks objects in a video using YOLOv8 model, saving an annotated output video.

    Every frame is resized to 640x640, passed through ``model.track`` with a
    confidence threshold of 0.7, annotated, and written to an MP4 file in the
    ``results`` directory (relative to the current working directory). The
    output keeps the source frame rate when the source reports one and falls
    back to 30 FPS otherwise.

    Args:
        path (str): Path to the input video file.

    Returns:
        Tuple[Set[str], str]:
            - Set of unique detected object labels (e.g., {'Gun', 'knife'})
            - Path to the output annotated video

    Raises:
        FileNotFoundError: If input video doesn't exist
        ValueError: If the input video cannot be opened or the output video
            cannot be created
    """
    # Validate input file exists
    if not os.path.exists(path):
        raise FileNotFoundError(f"Video file not found: {path}")

    # Model is trained to detect: ['Fire', 'Gun', 'License_Plate', 'Smoke', 'knife']
    model = YOLO(os.path.join(os.path.dirname(__file__), "yolo", "best.pt"))
    class_names = model.names  # Get class label mappings

    # Frames are first written to a temporary file; it is renamed once the
    # detected labels (which are part of the final filename) are known.
    base_name = os.path.splitext(os.path.basename(path))[0]
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)  # raises OSError on failure
    temp_output_path = os.path.join(output_dir, f"{base_name}_output_temp.mp4")

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Failed to open video file: {path}")

    # Process all frames at 640x640 resolution for consistency
    frame_width, frame_height = 640, 640

    # Keep the source frame rate so the output plays at the original speed;
    # fall back to 30 FPS when the source does not report a usable value.
    source_fps = cap.get(cv2.CAP_PROP_FPS)
    output_fps = source_fps if source_fps > 0 else 30.0

    out = cv2.VideoWriter(
        temp_output_path,
        cv2.VideoWriter_fourcc(*'mp4v'),  # MP4 codec
        output_fps,
        (frame_width, frame_height)
    )

    unique_crimes: Set[str] = set()
    try:
        if not out.isOpened():
            raise ValueError(f"Failed to create output video: {temp_output_path}")

        start = time.time()
        print(f"[INFO] Processing started at {start:.2f} seconds")

        while True:
            ret, frame = cap.read()
            if not ret:  # End of video
                break

            # Resize and run detection + tracking
            frame = cv2.resize(frame, (frame_width, frame_height))
            results = model.track(
                source=frame,
                conf=0.7,      # Minimum confidence threshold
                persist=True   # Enable tracking across frames
            )

            # Annotate frame with boxes and tracking IDs
            annotated_frame = results[0].plot()

            # Record detected classes
            for box in results[0].boxes:
                unique_crimes.add(class_names[int(box.cls)])

            out.write(annotated_frame)

        end = time.time()
        print(f"[INFO] Processing finished at {end:.2f} seconds")
        print(f"[INFO] Total execution time: {end - start:.2f} seconds")
    finally:
        # Release video resources even if processing raised.
        cap.release()
        out.release()

    # Final filename format: {original_name}_{detected_objects}_output.mp4
    # (the label part is omitted when nothing was detected).
    crimes_str = "_".join(sorted(unique_crimes)).replace(" ", "_")[:50]  # truncate if needed
    if crimes_str:
        final_output_name = f"{base_name}_{crimes_str}_output.mp4"
    else:
        final_output_name = f"{base_name}_output.mp4"
    final_output_path = os.path.join(output_dir, final_output_name)

    # os.replace overwrites an existing file from an earlier run, whereas
    # os.rename raises on Windows when the destination already exists.
    os.replace(temp_output_path, final_output_path)

    print(f"[INFO] Detected crimes: {unique_crimes}")
    print(f"[INFO] Annotated video saved at: {final_output_path}")

    return unique_crimes, final_output_path
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# # Entry point
|
| 118 |
+
# path0 = input("Enter the local path to the video file to detect objects: ")
|
| 119 |
+
# path = path0.strip('"') # Remove extra quotes if copied from Windows
|
| 120 |
+
# print(f"[INFO] Loading video: {path}")
|
| 121 |
+
# detection(path)
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=3.0
|
| 2 |
+
tensorflow>=2.10
|
| 3 |
+
opencv-python>=4.6
|
| 4 |
+
ultralytics>=8.0
|
| 5 |
+
numpy>=1.22
|
| 6 |
+
matplotlib>=3.6
|
trainig.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import traceback
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from tensorflow.keras.utils import Sequence
|
| 8 |
+
from tensorflow.keras.models import Sequential, load_model
|
| 9 |
+
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, Flatten, Dense, Dropout, BatchNormalization
|
| 10 |
+
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
|
| 11 |
+
import tensorflow as tf
|
| 12 |
+
|
| 13 |
+
# === CONFIG ===
# NOTE(review): the three paths below are absolute paths on the original
# development machine; they must be edited before running elsewhere.
DATA_DIR = "D:\\K_REPO\\ComV\\train"  # root of the training videos
N_FRAMES = 30  # frames per video
IMG_SIZE = (96, 96)  # frame size
EPOCHS = 10  # presumably the number of training epochs - usage not visible here
BATCH_SIZE = 14  # presumably the generator batch size - usage not visible here
CHECKPOINT_DIR = r"D:\K_REPO\ComV\AI_made\trainnig_output\checkpoint"
RESUME_TRAINING = 1  # presumably 1 = resume from an existing checkpoint - confirm against the training code
MIN_REQUIRED_FRAMES = 10  # videos with fewer frames are skipped by VideoDataGenerator
OUTPUT_PATH = r"D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5"
# Optimize OpenCV
cv2.setUseOptimized(True)
cv2.setNumThreads(8)
|
| 26 |
+
|
| 27 |
+
# === VIDEO DATA GENERATOR ===
|
| 28 |
+
class VideoDataGenerator(Sequence):
    """Keras ``Sequence`` producing batches of fixed-length video clips.

    Each item is a tuple ``(X, y)`` where ``X`` stacks ``n_frames`` RGB frames
    per video, resized to ``img_size`` (interpreted as ``(height, width)``)
    and scaled to ``[0, 1]``, and ``y`` holds the matching labels.

    Videos that cannot be opened, or that report fewer than
    ``MIN_REQUIRED_FRAMES`` frames, are dropped when the generator is built.
    """

    def __init__(self, video_paths, labels, batch_size, n_frames, img_size):
        """Filter the input videos and store the batching parameters.

        Args:
            video_paths: Paths of the candidate video files.
            labels: Labels aligned with ``video_paths``.
            batch_size: Number of videos per batch.
            n_frames: Number of frames sampled from every video.
            img_size: Target frame size as ``(height, width)``.
        """
        # Keras 3 expects Sequence/PyDataset subclasses to run the base
        # initializer; on older tf.keras versions this call is a no-op.
        super().__init__()
        self.video_paths, self.labels = self._filter_invalid_videos(video_paths, labels)
        self.batch_size = batch_size
        self.n_frames = n_frames
        self.img_size = img_size
        self.indices = np.arange(len(self.video_paths))
        print(f"[INFO] Final dataset size: {len(self.video_paths)} videos")

    def _filter_invalid_videos(self, paths, labels):
        """Return the ``(paths, labels)`` pairs whose video is usable.

        A video is kept when OpenCV can open it and it reports at least
        ``MIN_REQUIRED_FRAMES`` frames; every rejection is logged.
        """
        valid_paths = []
        valid_labels = []

        for path, label in zip(paths, labels):
            cap = cv2.VideoCapture(path)
            try:
                if not cap.isOpened():
                    print(f"[WARNING] Could not open video: {path}")
                    continue
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                # Always give the capture handle back, even on the skip path.
                cap.release()

            if total_frames < MIN_REQUIRED_FRAMES:
                print(f"[WARNING] Skipping {path} - only {total_frames} frames (needs at least {MIN_REQUIRED_FRAMES})")
                continue

            valid_paths.append(path)
            valid_labels.append(label)

        return valid_paths, valid_labels

    def __len__(self):
        """Return the number of batches per epoch (last one may be partial)."""
        return int(np.ceil(len(self.video_paths) / self.batch_size))

    def __getitem__(self, index):
        """Build batch number ``index`` as ``(X, y)`` NumPy arrays.

        A video that fails to load is replaced by an all-zero clip (keeping
        its label) so that one bad file does not abort the whole epoch.
        """
        batch_indices = self.indices[index*self.batch_size:(index+1)*self.batch_size]
        X, y = [], []

        for i in batch_indices:
            path = self.video_paths[i]
            label = self.labels[i]
            try:
                frames = self._load_video_frames(path)
                X.append(frames)
                y.append(label)
            except Exception as e:
                # Deliberate best-effort: log and substitute a blank clip.
                print(f"[WARNING] Error processing {path} - {str(e)}")
                X.append(np.zeros((self.n_frames, *self.img_size, 3)))
                y.append(label)

        return np.array(X), np.array(y)

    def _load_video_frames(self, path):
        """Sample ``n_frames`` evenly spaced RGB frames from ``path``.

        Unreadable frames become black frames; short videos are padded by
        repeating the last frame. Returns an array of shape
        ``(n_frames, height, width, 3)`` with values in ``[0, 1]``.
        """
        height, width = self.img_size
        cap = cv2.VideoCapture(path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Never request more sample points than the video has frames.
            frame_indices = np.linspace(
                0, total_frames - 1, min(total_frames, self.n_frames), dtype=np.int32
            )

            frames = []
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    frame = np.zeros((height, width, 3), dtype=np.uint8)
                else:
                    # cv2.resize takes dsize as (width, height), whereas
                    # img_size is (height, width) everywhere else.
                    frame = cv2.resize(frame, (width, height))
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
        finally:
            # Release the handle even if decoding or resizing raised.
            cap.release()

        while len(frames) < self.n_frames:
            frames.append(frames[-1] if frames else np.zeros((height, width, 3), dtype=np.uint8))

        return np.array(frames) / 255.0

    def on_epoch_end(self):
        """Reshuffle the sample order in place after every epoch."""
        np.random.shuffle(self.indices)
|
| 109 |
+
|
| 110 |
+
def create_model():
    """Build and compile the 3D-CNN binary video classifier.

    The network takes clips of shape ``(N_FRAMES, *IMG_SIZE, 3)``, applies
    three Conv3D -> MaxPooling3D -> BatchNormalization stages, then a dense
    head ending in a single sigmoid unit. It is compiled with the Adam
    optimizer, binary cross-entropy loss and the accuracy metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    # (filters, pool_size) of each convolutional stage, in order.
    conv_stages = (
        (32, (1, 2, 2)),
        (64, (1, 2, 2)),
        (128, (2, 2, 2)),
    )

    stack = [Input(shape=(N_FRAMES, *IMG_SIZE, 3))]
    for n_filters, pool_shape in conv_stages:
        stack.append(Conv3D(n_filters, kernel_size=(3, 3, 3), activation='relu', padding='same'))
        stack.append(MaxPooling3D(pool_size=pool_shape))
        stack.append(BatchNormalization())

    # Classification head.
    stack.extend([
        Flatten(),
        Dense(256, activation='relu'),
        Dropout(0.5),
        Dense(1, activation='sigmoid'),
    ])

    network = Sequential(stack)
    network.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return network
|
| 136 |
+
|
| 137 |
+
def load_data():
    """Collect the labelled video paths and split them into train/validation.

    Looks for the ``Fighting`` (label 1) and ``Normal`` (label 0)
    sub-directories of ``DATA_DIR`` and gathers every ``.mp4``, ``.mpeg``,
    ``.avi`` or ``.mov`` file inside them.

    Returns:
        The result of ``train_test_split``: ``train_paths, val_paths,
        train_labels, val_labels`` (80/20 split, stratified when both
        classes are present).

    Raises:
        FileNotFoundError: If one of the class directories is missing.
        ValueError: If no video file is found at all.
    """
    video_paths, labels = [], []
    for label_name in ["Fighting", "Normal"]:
        label_dir = os.path.join(DATA_DIR, label_name)
        if not os.path.isdir(label_dir):
            raise FileNotFoundError(f"Directory not found: {label_dir}")

        label = 1 if label_name.lower() == "fighting" else 0

        # os.listdir returns entries in arbitrary order. Sorting makes the
        # seeded split below reproducible, so a resumed run sees the same
        # train/validation partition as the run that wrote the checkpoint.
        for file in sorted(os.listdir(label_dir)):
            if file.lower().endswith((".mp4", ".mpeg", ".avi", ".mov")):
                full_path = os.path.join(label_dir, file)
                video_paths.append(full_path)
                labels.append(label)

    if not video_paths:
        raise ValueError(f"No videos found in {DATA_DIR}")

    print(f"[INFO] Total videos: {len(video_paths)} (Fighting: {labels.count(1)}, Normal: {labels.count(0)})")

    if len(set(labels)) > 1:
        return train_test_split(video_paths, labels, test_size=0.2, stratify=labels, random_state=42)
    else:
        print("[WARNING] Only one class found. Splitting without stratification.")
        return train_test_split(video_paths, labels, test_size=0.2, random_state=42)
|
| 162 |
+
|
| 163 |
+
def get_latest_checkpoint():
    """Return the path of the highest-numbered checkpoint, or ``None``.

    Checkpoints are files in ``CHECKPOINT_DIR`` named ``ckpt_<epoch>.h5``.
    The directory is created when it does not exist yet. Files matching the
    ``ckpt_*.h5`` pattern whose epoch part is not an integer are ignored
    instead of aborting the lookup.
    """
    if not os.path.exists(CHECKPOINT_DIR):
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        return None

    def epoch_of(filename):
        # Epoch number encoded in the file name, or None if it is not numeric.
        try:
            return int(filename.split('_')[1].split('.')[0])
        except ValueError:
            return None

    candidates = []
    for filename in os.listdir(CHECKPOINT_DIR):
        if filename.startswith('ckpt_') and filename.endswith('.h5'):
            epoch = epoch_of(filename)
            if epoch is not None:
                candidates.append((epoch, filename))

    if not candidates:
        return None

    # Highest epoch wins; the file name breaks ties deterministically.
    _, latest = max(candidates)
    return os.path.join(CHECKPOINT_DIR, latest)
|
| 175 |
+
|
| 176 |
+
def main():
    """Run the full training pipeline.

    Steps: load and split the dataset, build the train/validation
    generators, create the model (or reload the latest checkpoint when
    ``RESUME_TRAINING`` is truthy), train it, and save the model to
    ``OUTPUT_PATH``. Failures in the setup steps are logged and end the
    function early.
    """
    # Load and split data
    try:
        train_paths, val_paths, train_labels, val_labels = load_data()
    except Exception as e:
        print(f"[ERROR] Failed to load data: {str(e)}")
        return

    # Create data generators
    try:
        train_gen = VideoDataGenerator(train_paths, train_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
        val_gen = VideoDataGenerator(val_paths, val_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
    except Exception as e:
        print(f"[ERROR] Failed to create data generators: {str(e)}")
        return

    # Callbacks
    callbacks = [
        ModelCheckpoint(
            os.path.join(CHECKPOINT_DIR, 'ckpt_{epoch}.h5'),
            save_best_only=False,
            save_weights_only=False
        ),
        CSVLogger('training_log.csv', append=True),
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    ]

    # Handle resume training
    initial_epoch = 0
    try:
        if RESUME_TRAINING:
            ckpt = get_latest_checkpoint()
            if ckpt:
                print(f"[INFO] Resuming training from checkpoint: {ckpt}")
                model = load_model(ckpt)
                # Parse the epoch from the file name only: the directory part
                # of the path may itself contain underscores, which would
                # otherwise be picked up by the split and break int().
                ckpt_name = os.path.basename(ckpt)
                initial_epoch = int(ckpt_name.split('_')[1].split('.')[0])
            else:
                print("[INFO] No checkpoint found, starting new training")
                model = create_model()
        else:
            model = create_model()
    except Exception as e:
        print(f"[ERROR] Failed to initialize model: {str(e)}")
        return

    # Display model summary
    model.summary()

    # Train model
    training_succeeded = False
    try:
        print("[INFO] Starting training...")
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=EPOCHS,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            verbose=1
        )
        training_succeeded = True
    except Exception as e:
        print(f"[ERROR] Training failed: {str(e)}")
        traceback.print_exc()
    finally:
        # Save whatever state the model reached, even after a failure or an
        # interruption, so that progress is not lost.
        model.save(OUTPUT_PATH)
        if training_succeeded:
            print("[INFO] Training completed. Model saved to final_model_2.h5")
        else:
            print("[WARNING] Training did not complete. Current model state saved to final_model_2.h5")
|
| 241 |
+
|
| 242 |
+
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    print("[INFO] Starting script...")
    main()
    print("[INFO] Script execution completed.")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|