binuser007 committed on
Commit
9b1cbb7
·
verified ·
1 Parent(s): d3317b5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +346 -0
  2. config.py +67 -0
app.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import tempfile
4
+ import requests
5
+ import base64
6
+ import numpy as np
7
+ import logging
8
+ from dataclasses import dataclass
9
+ from typing import Optional, Union, Tuple
10
+ from PIL import Image
11
+ from io import BytesIO
12
+ from ultralytics import YOLO
13
+ import streamlit as st
14
+ import yt_dlp as youtube_dl
15
+ from config import Config
16
+
17
+ # Configure logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
@dataclass
class DetectionResult:
    """Outcome of a single detection pass.

    ``success`` reports whether detection completed; ``image`` carries the
    annotated array on success and ``error_message`` carries the failure
    text otherwise.
    """

    # Field order is part of the interface: callers build instances
    # positionally, e.g. DetectionResult(True, annotated_image).
    success: bool
    image: Optional[np.ndarray] = None
    error_message: Optional[str] = None
27
+
28
@st.cache_resource
def load_yolo_model(model_name: str) -> YOLO:
    """Load a YOLO model by name (wrapped in ``st.cache_resource``).

    Args:
        model_name: Weights identifier; must be listed in
            ``Config.AVAILABLE_MODELS``.

    Returns:
        The ``YOLO`` instance constructed from ``model_name``.

    Raises:
        RuntimeError: If ``model_name`` is not an available model or the
            ``YOLO`` constructor fails. The underlying exception is chained
            as ``__cause__``.
    """
    try:
        if model_name not in Config.AVAILABLE_MODELS:
            raise ValueError(f"Invalid model name: {model_name}")
        return YOLO(model_name)
    except Exception as e:
        # logger.exception keeps the traceback; callers still only ever see
        # RuntimeError, now explicitly chained to the original failure.
        logger.exception(f"Error loading model: {e}")
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
38
+
39
class YOLOModel:
    """Wrap a loaded YOLO model and draw its detections onto images."""

    def __init__(self, model_name: str = Config.DEFAULT_MODEL):
        """Validate the configuration, then load ``model_name``.

        Raises:
            RuntimeError: If ``Config.validate_config()`` returns a falsy
                value, or propagated from ``load_yolo_model``.
        """
        if not Config.validate_config():
            raise RuntimeError("Invalid configuration")
        self.model = load_yolo_model(model_name)

    def detect_objects(self, image: np.ndarray) -> DetectionResult:
        """Run the model on ``image`` and return an annotated copy.

        Boxes whose confidence is below ``Config.CONFIDENCE_THRESHOLD`` are
        skipped. The input array is not modified; drawing happens on a copy.

        Args:
            image: Image array handed to the model and to the OpenCV drawing
                calls (assumes H x W x 3 uint8 -- TODO confirm with callers).

        Returns:
            ``DetectionResult(True, annotated)`` on success, or a failed
            result carrying the exception text if anything raises.
        """
        if self.model is None:
            return DetectionResult(False, error_message="Model not loaded")

        try:
            results = self.model(image)
            annotated_image = image.copy()
            names = self.model.names
            font = cv2.FONT_HERSHEY_SIMPLEX

            for box in results[0].boxes:
                confidence = box.conf.item()
                # Filter first so rejected boxes cost no further work.
                if confidence < Config.CONFIDENCE_THRESHOLD:
                    continue

                x1, y1, x2, y2 = map(int, box.xyxy[0])
                label = names[int(box.cls)]

                cv2.rectangle(
                    annotated_image,
                    (x1, y1),
                    (x2, y2),
                    Config.BBOX_COLOR,
                    2
                )

                label_text = f'{label} {confidence:.2f}'
                # putText anchors at the text baseline; for boxes touching
                # the top edge, y1 - 10 would place the label outside the
                # frame, so clamp it to the text height.
                (_, text_height), _ = cv2.getTextSize(
                    label_text, font, Config.FONT_SCALE, Config.FONT_THICKNESS
                )
                text_y = max(y1 - 10, text_height)
                cv2.putText(
                    annotated_image,
                    label_text,
                    (x1, text_y),
                    font,
                    Config.FONT_SCALE,
                    Config.BBOX_COLOR,
                    Config.FONT_THICKNESS
                )

            return DetectionResult(True, annotated_image)
        except Exception as e:
            logger.exception(f"Error during object detection: {e}")
            return DetectionResult(False, error_message=str(e))
85
+
86
class ImageProcessor:
    """Load still images (PIL object, http(s) URL or data URI) and run detection."""

    # Seconds to wait for a remote image before giving up; without a timeout
    # requests.get can block indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, model: YOLOModel):
        self.model = model

    def process_image(self, image: Union[Image.Image, str]) -> DetectionResult:
        """Run detection on a PIL image or on an image referenced by URL.

        Args:
            image: A PIL image, or a string that is either a ``data:image``
                URI or a URL fetched with ``requests``.

        Returns:
            The detection result with ``image`` in RGB channel order, or a
            failed result carrying an error message.
        """
        try:
            if isinstance(image, str):
                image = self._load_image_from_url(image)

            if image is None:
                return DetectionResult(False, error_message="Failed to load image")

            # Normalise every mode (RGBA, L, P, CMYK, ...) to 3-channel RGB;
            # converting only RGBA left other modes with the wrong shape.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            rgb_image = np.array(image)
            # PIL yields RGB, while the model and the OpenCV drawing code work
            # in BGR (the same layout video frames arrive in), so convert on
            # the way in and back on the way out.
            bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
            result = self.model.detect_objects(bgr_image)
            if result.success and result.image is not None:
                return DetectionResult(
                    True, cv2.cvtColor(result.image, cv2.COLOR_BGR2RGB)
                )
            return result
        except Exception as e:
            logger.exception(f"Error processing image: {e}")
            return DetectionResult(False, error_message=str(e))

    def _load_image_from_url(self, url: str) -> Optional[Image.Image]:
        """Open an image from a ``data:image`` URI or a remote URL.

        Returns:
            The opened PIL image, or ``None`` if decoding or the request
            fails (the error is logged).
        """
        try:
            if url.startswith('data:image'):
                # Everything after the first comma is the base64 payload.
                _, encoded = url.split(',', 1)
                image_data = base64.b64decode(encoded)
                return Image.open(BytesIO(image_data))

            response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        except Exception as e:
            logger.error(f"Error loading image from URL: {e}")
            return None
124
+
125
class VideoProcessor:
    """Run detection frame by frame over a video and write an annotated copy."""

    # Used when the capture reports no usable frame rate.
    FALLBACK_FPS = 30.0

    def __init__(self, model: YOLOModel):
        self.model = model
        os.makedirs(Config.TEMP_DIR, exist_ok=True)

    def process_video(self, input_path: str) -> Tuple[bool, Optional[str]]:
        """Annotate every frame of ``input_path``.

        Progress is reported through a Streamlit progress bar and status
        text, both cleared when processing ends.

        Args:
            input_path: Path of the video file to read.

        Returns:
            ``(True, output_path)`` on success, otherwise
            ``(False, error_message)``.
        """
        cap = None
        writer = None
        progress_bar = st.progress(0)
        status_text = st.empty()

        try:
            cap = cv2.VideoCapture(input_path)
            if not cap.isOpened():
                return False, "Cannot open video file"

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total_frames <= 0:
                return False, "Invalid video file"

            output_path = os.path.join(Config.TEMP_DIR, "processed_video.mp4")
            # Writing to the file that is being read would corrupt both.
            if os.path.abspath(input_path) == os.path.abspath(output_path):
                return False, "Input video path collides with the output path"

            writer = self._setup_video_writer(cap, output_path)
            # VideoWriter does not raise on failure; without this check the
            # loop would "succeed" and return a path to a missing file.
            if not writer.isOpened():
                return False, "Cannot create output video file"

            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1
                progress = min(100, int(frame_count * 100 / total_frames))
                progress_bar.progress(progress)
                status_text.text(f"Processing frame {frame_count}/{total_frames}")

                result = self.model.detect_objects(frame)
                # Keep the original frame when detection fails so the output
                # keeps its length instead of silently losing frames.
                writer.write(result.image if result.success else frame)

            return True, output_path
        except Exception as e:
            logger.exception(f"Error processing video: {e}")
            return False, str(e)
        finally:
            if cap is not None:
                cap.release()
            if writer is not None:
                writer.release()
            progress_bar.empty()
            status_text.empty()

    def _setup_video_writer(self, cap: cv2.VideoCapture, output_path: str) -> cv2.VideoWriter:
        """Create a writer matching the capture's frame size and frame rate.

        The codec comes from ``Config.VIDEO_OUTPUT_FORMAT``. When the capture
        reports a frame rate that is not positive, ``FALLBACK_FPS`` is used.
        """
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # also catches NaN
            fps = self.FALLBACK_FPS
        fourcc = cv2.VideoWriter_fourcc(*Config.VIDEO_OUTPUT_FORMAT)
        return cv2.VideoWriter(output_path, fourcc, fps, (width, height))
184
+
185
def download_youtube_video(youtube_url: str) -> Optional[str]:
    """Download a YouTube video and return the path of the saved file.

    Each download gets its own file name inside ``Config.TEMP_DIR``. A fixed
    name let a leftover file from an earlier download be returned for a
    different URL, because the downloader does not overwrite an existing
    video file by default.

    Args:
        youtube_url: URL handed to ``yt_dlp``.

    Returns:
        Path to the downloaded file, or ``None`` if the download failed
        (the error is logged).
    """
    import uuid  # local import: only needed to build a unique file name

    try:
        os.makedirs(Config.TEMP_DIR, exist_ok=True)
        output_path = os.path.join(
            Config.TEMP_DIR, f"youtube_{uuid.uuid4().hex}.mp4"
        )
        ydl_opts = {
            'format': 'best',
            'outtmpl': output_path,
        }
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            ydl.download([youtube_url])

        if not os.path.isfile(output_path):
            logger.error(f"Failed to retrieve video from YouTube: {output_path} was not created")
            return None
        return output_path
    except Exception as e:
        logger.error(f"Failed to retrieve video from YouTube: {e}")
        return None
200
+
201
def cleanup_temp_files():
    """Delete every regular file directly inside ``Config.TEMP_DIR``.

    Best effort: a failure to list the directory or to delete an individual
    file is logged and never raised. Subdirectories are left alone.
    """
    try:
        names = os.listdir(Config.TEMP_DIR)
    except Exception as e:
        logger.error(f"Error cleaning up temp directory: {e}")
        return

    for name in names:
        file_path = os.path.join(Config.TEMP_DIR, name)
        try:
            if not os.path.isfile(file_path):
                continue
            os.unlink(file_path)
        except Exception as e:
            logger.error(f"Error deleting {file_path}: {e}")
213
+
214
def validate_image(image: "Image.Image", max_dimension: int = 1920) -> Tuple[bool, str]:
    """Check that an image has a supported mode, fits the size limit and decodes.

    Args:
        image: The PIL image to check.
        max_dimension: Largest width or height accepted, in pixels.

    Returns:
        ``(True, "Image is valid")`` when every check passes, otherwise
        ``(False, reason)``.
    """
    try:
        if image.mode not in ('RGB', 'RGBA'):
            return False, f"Unsupported image mode: {image.mode}"

        width, height = image.size
        if width > max_dimension or height > max_dimension:
            return False, f"Image too large. Maximum dimension: {max_dimension}px"

        # load() decodes the pixel data and raises on truncated or corrupt
        # files while leaving the image usable. verify() was used before,
        # but an image must be reopened after verify(), so a "valid" image
        # could not be processed by the caller afterwards.
        image.load()
        return True, "Image is valid"
    except Exception as e:
        return False, str(e)
232
+
233
def _show_detection_result(result: DetectionResult) -> None:
    """Render the annotated image, or the error message when detection failed."""
    if result.success:
        st.image(result.image, caption="Processed Image", use_container_width=True)
    else:
        st.error(result.error_message)


def _run_video_detection(video_processor: VideoProcessor, input_video_path: str,
                         error_prefix: str = "") -> None:
    """Process one video, show the outcome, then remove temporary files."""
    try:
        success, result = video_processor.process_video(input_video_path)
        if success:
            st.video(result)
        else:
            st.error(f"{error_prefix}{result}")
    finally:
        cleanup_temp_files()
        # cleanup_temp_files only covers Config.TEMP_DIR; the input may live
        # elsewhere (e.g. a download), so remove it explicitly as well.
        try:
            if os.path.isfile(input_video_path):
                os.unlink(input_video_path)
        except OSError as e:
            logger.error(f"Error deleting {input_video_path}: {e}")


def main():
    """Build the Streamlit UI: model selection plus image and video detection tabs."""
    st.title("MULTIMEDIA OBJECT DETECTION USING YOLO")

    # Model selection with description
    st.subheader("Model Selection")
    available_models = Config.AVAILABLE_MODELS
    # A misconfigured DEFAULT_MODEL must not crash the page with ValueError
    # from .index(); fall back to the first entry instead.
    default_index = (
        available_models.index(Config.DEFAULT_MODEL)
        if Config.DEFAULT_MODEL in available_models
        else 0
    )
    model_choice = st.selectbox(
        "Select YOLO Model",
        options=available_models,
        index=default_index,
        format_func=lambda x: f"{x} - {Config.YOLO_MODELS[x]}"
    )

    # Re-create the model only when none is stored yet or the choice changed.
    if 'model' not in st.session_state or st.session_state.get('model_choice') != model_choice:
        try:
            st.session_state.model = YOLOModel(model_choice)
            st.session_state.model_choice = model_choice
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return

    model = st.session_state.model
    image_processor = ImageProcessor(model)
    video_processor = VideoProcessor(model)

    # Display model capabilities
    if "pose" in model_choice:
        st.info("This model will detect and estimate human poses in the image/video.")
    elif "seg" in model_choice:
        st.info("This model will perform instance segmentation, creating precise masks for detected objects.")
    else:
        st.info("This model will detect and classify objects with bounding boxes.")

    tabs = st.tabs(["Image Detection", "Video Detection"])

    with tabs[0]:
        st.header("Image Detection")
        input_choice = st.radio("Select Input Method", ["Upload", "URL"])

        if input_choice == "Upload":
            uploaded_image = st.file_uploader(
                "Upload Image",
                type=Config.ALLOWED_IMAGE_TYPES
            )
            if uploaded_image is not None:
                try:
                    image = Image.open(uploaded_image)
                except Exception as e:
                    # An unreadable upload should produce a message, not a
                    # traceback in the page.
                    st.error(f"Failed to open image: {e}")
                else:
                    _show_detection_result(image_processor.process_image(image))

        elif input_choice == "URL":
            image_url = st.text_input("Image URL")
            if image_url:
                _show_detection_result(image_processor.process_image(image_url))

    with tabs[1]:
        st.header("Video Detection")
        video_choice = st.radio("Select Input Method", ["Upload", "YouTube"])

        if video_choice == "Upload":
            try:
                uploaded_video = st.file_uploader(
                    "Upload Local Video",
                    type=Config.ALLOWED_VIDEO_TYPES
                )
                if uploaded_video is not None:
                    if uploaded_video.size > 200 * 1024 * 1024:  # 200MB limit
                        st.error("Video file is too large. Please upload a file smaller than 200MB.")
                        return

                    # Never use the client-supplied name as a path: it could
                    # escape TEMP_DIR or collide with the processed output.
                    # Keep only the extension and let tempfile pick the name.
                    suffix = os.path.splitext(os.path.basename(uploaded_video.name))[1]
                    with tempfile.NamedTemporaryFile(
                        dir=Config.TEMP_DIR, suffix=suffix, delete=False
                    ) as f:
                        f.write(uploaded_video.read())
                        input_video_path = f.name

                    _run_video_detection(
                        video_processor, input_video_path, "Error processing video: "
                    )
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")

        elif video_choice == "YouTube":
            video_url = st.text_input("YouTube Video URL")
            if video_url:
                with st.spinner("Downloading video..."):
                    input_video_path = download_youtube_video(video_url)
                if input_video_path:
                    _run_video_detection(video_processor, input_video_path)
                else:
                    st.error("Failed to download YouTube video")


if __name__ == "__main__":
    main()
config.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ import os
3
+ from dotenv import load_dotenv
4
+
5
+ # Load environment variables
6
+ load_dotenv()
7
+
8
class Config:
    """Application settings; most values can be overridden by environment variables."""

    # Model configurations with descriptions
    YOLO_MODELS: Dict[str, str] = {
        "yolov8n.pt": "YOLOv8 Nano - Fastest and smallest model, best for CPU/edge devices",
        "yolov8s.pt": "YOLOv8 Small - Good balance of speed and accuracy",
        "yolov8m.pt": "YOLOv8 Medium - Better accuracy, still reasonable speed",
        "yolov8l.pt": "YOLOv8 Large - High accuracy, slower speed",
        "yolov8x.pt": "YOLOv8 XLarge - Highest accuracy, slowest speed",
        # Pose estimation models
        "yolov8n-pose.pt": "YOLOv8 Nano Pose - Fast pose estimation",
        "yolov8s-pose.pt": "YOLOv8 Small Pose - Balanced pose estimation",
        "yolov8m-pose.pt": "YOLOv8 Medium Pose - Accurate pose estimation",
        "yolov8l-pose.pt": "YOLOv8 Large Pose - High accuracy pose estimation",
        "yolov8x-pose.pt": "YOLOv8 XLarge Pose - Most accurate pose estimation",
        # Segmentation models
        "yolov8n-seg.pt": "YOLOv8 Nano Segmentation - Fast instance segmentation",
        "yolov8s-seg.pt": "YOLOv8 Small Segmentation - Balanced segmentation",
        "yolov8m-seg.pt": "YOLOv8 Medium Segmentation - Accurate segmentation",
        "yolov8l-seg.pt": "YOLOv8 Large Segmentation - High accuracy segmentation",
        "yolov8x-seg.pt": "YOLOv8 XLarge Segmentation - Most accurate segmentation"
    }

    AVAILABLE_MODELS: List[str] = list(YOLO_MODELS.keys())
    DEFAULT_MODEL: str = os.getenv('DEFAULT_MODEL', 'yolov8s.pt')

    # File configurations
    ALLOWED_IMAGE_TYPES: List[str] = ["jpg", "jpeg", "png"]
    ALLOWED_VIDEO_TYPES: List[str] = ["mp4", "mov", "avi"]

    # Video processing
    TEMP_DIR: str = os.getenv('TEMP_DIR', 'temp')
    VIDEO_OUTPUT_FORMAT: str = os.getenv('VIDEO_OUTPUT_FORMAT', 'mp4v')
    MAX_VIDEO_DURATION: int = int(os.getenv('MAX_VIDEO_DURATION', '300'))  # 5 minutes default

    # UI configurations
    CONFIDENCE_THRESHOLD: float = float(os.getenv('CONFIDENCE_THRESHOLD', '0.25'))
    BBOX_COLOR: tuple = tuple(map(int, os.getenv('BBOX_COLOR', '0,255,0').split(',')))
    FONT_SCALE: float = float(os.getenv('FONT_SCALE', '0.5'))
    FONT_THICKNESS: int = int(os.getenv('FONT_THICKNESS', '2'))

    # Cache settings
    CACHE_DIR: str = os.getenv('CACHE_DIR', '.cache')
    MAX_CACHE_SIZE: int = int(os.getenv('MAX_CACHE_SIZE', '1024'))  # MB

    @classmethod
    def validate_config(cls) -> bool:
        """Check the settings and create the working directories.

        Besides the default model, this validates the values that would
        otherwise only fail later, deep inside image or video processing:
        the confidence threshold, the box colour and the video codec code.

        Returns:
            ``True`` when every check passes and both directories exist or
            could be created; ``False`` otherwise (the reason is printed).
        """
        try:
            # Validate model exists
            if cls.DEFAULT_MODEL not in cls.AVAILABLE_MODELS:
                raise ValueError(f"Invalid default model: {cls.DEFAULT_MODEL}")

            if not 0.0 <= cls.CONFIDENCE_THRESHOLD <= 1.0:
                raise ValueError(
                    f"Confidence threshold must be within [0, 1]: {cls.CONFIDENCE_THRESHOLD}"
                )

            # The colour is passed to the drawing calls as three channels.
            if len(cls.BBOX_COLOR) != 3 or not all(0 <= c <= 255 for c in cls.BBOX_COLOR):
                raise ValueError(
                    f"Bounding box color must be three values in 0-255: {cls.BBOX_COLOR}"
                )

            # The codec string is unpacked into four separate characters.
            if len(cls.VIDEO_OUTPUT_FORMAT) != 4:
                raise ValueError(
                    f"Video output format must be a 4-character code: {cls.VIDEO_OUTPUT_FORMAT}"
                )

            # Validate directories exist or can be created
            os.makedirs(cls.TEMP_DIR, exist_ok=True)
            os.makedirs(cls.CACHE_DIR, exist_ok=True)

            return True
        except Exception as e:
            print(f"Configuration validation failed: {str(e)}")
            return False