# Multi-Tagger — modules/media_handler.py
# Uploaded by Werli ("Upload 6 files", commit b0b0867, verified)
import os
import tempfile
from typing import List, Union, Tuple, Optional
from modules.video_processor import is_video_file, process_video_upload, SUPPORTED_VIDEO_FORMATS
# Supported image formats: lowercase file extensions including the leading dot,
# matched against the extension that get_media_type() extracts from a path.
SUPPORTED_IMAGE_FORMATS = [
    '.jpg',
    '.jpeg',
    '.png',
    '.bmp',
    '.tiff',
    '.webp',
    '.gif',
]
def get_media_type(file_path: str) -> str:
    """
    Classify a path as image, video, or unsupported by its file extension.

    The path is lowercased before its extension is looked up; the file is
    never opened or checked for existence. Image formats are checked first,
    so an extension present in both lists is reported as 'image'.

    Args:
        file_path: Path (or bare filename) to classify

    Returns:
        'image', 'video', or 'unsupported' (also for an empty/None path)
    """
    if not file_path:
        return 'unsupported'

    extension = os.path.splitext(file_path.lower())[1]
    if extension in SUPPORTED_IMAGE_FORMATS:
        return 'image'
    if extension in SUPPORTED_VIDEO_FORMATS:
        return 'video'
    return 'unsupported'
def is_supported_media(file_path: str) -> bool:
    """Return True when get_media_type() classifies the path as 'image' or 'video'."""
    media_type = get_media_type(file_path)
    return media_type == 'image' or media_type == 'video'
def create_gallery_item(file_path: str) -> Optional[Tuple[str, str]]:
    """
    Create a gallery-compatible item from a media file path.

    Args:
        file_path: Path to the media file

    Returns:
        Tuple suitable for gallery (file_path, filename), or None when the
        path is empty, does not point to a regular file, or has an
        unsupported extension
    """
    # Guard falsy input up front, like the other entry points in this module;
    # os.path.exists/isfile can raise TypeError when given None.
    if not file_path:
        return None
    # isfile rather than exists: a directory whose name happens to end in a
    # media extension must not be offered to the gallery.
    if not os.path.isfile(file_path):
        return None
    if not is_supported_media(file_path):
        return None
    return (file_path, os.path.basename(file_path))
def process_single_media_upload(file_path: str, max_video_duration: int = 30, frame_interval: int = 1) -> List[str]:
    """
    Turn one uploaded file into the list of paths it contributes to the gallery.

    Images contribute their own path. Videos contribute whatever frame paths
    process_video_upload() returns as the first element of its result.
    Empty, nonexistent, or unsupported paths contribute nothing.

    Args:
        file_path: Path to the uploaded file
        max_video_duration: Maximum duration to process for videos (seconds)
        frame_interval: Interval between frames for videos (seconds)

    Returns:
        List of file paths to be added to gallery (images or extracted frames)
    """
    if not (file_path and os.path.exists(file_path)):
        return []

    kind = get_media_type(file_path)
    if kind == 'video':
        # Only the first element of process_video_upload's result is used here.
        frames, _ = process_video_upload(file_path, max_video_duration, frame_interval)
        return frames
    if kind == 'image':
        return [file_path]
    return []
def process_multiple_media_uploads(
    file_paths: List[str],
    max_video_duration: int = 30,
    frame_interval: int = 1
) -> List[str]:
    """
    Process several uploads and concatenate the paths each one contributes.

    Each input is handled by process_single_media_upload() in order, and the
    resulting lists are flattened into one.

    Args:
        file_paths: List of paths to uploaded files
        max_video_duration: Maximum duration to process for videos (seconds)
        frame_interval: Interval between frames for videos (seconds)

    Returns:
        List of file paths to be added to gallery (images and extracted frames)
    """
    return [
        produced
        for source in file_paths
        for produced in process_single_media_upload(source, max_video_duration, frame_interval)
    ]
def handle_single_media_upload(file_path: str, gallery: List, max_video_duration: int = 30, frame_interval: int = 1) -> Tuple[List, Optional[str]]:
    """
    Handle a single media file upload and return an updated gallery.

    The caller's ``gallery`` is not modified. A new list is returned that holds
    the existing items followed by a (path, filename) item for each image or
    extracted video frame.

    Args:
        file_path: Path to the uploaded file
        gallery: Current gallery items (None is treated as empty)
        max_video_duration: Maximum duration to process for videos (seconds)
        frame_interval: Interval between frames for videos (seconds)

    Returns:
        Tuple of (updated_gallery, None) for Gradio compatibility
    """
    # Copy instead of appending in place so the caller's list (e.g. a value held
    # in UI state) is never mutated as a side effect.
    updated_gallery = [] if gallery is None else list(gallery)
    if not file_path:
        return updated_gallery, None

    processed_paths = process_single_media_upload(file_path, max_video_duration, frame_interval)
    # create_gallery_item returns None for missing/unsupported paths; skip those.
    updated_gallery.extend(
        item for item in map(create_gallery_item, processed_paths) if item
    )
    return updated_gallery, None
def handle_multiple_media_uploads(
    file_paths: List,
    gallery: List,
    max_video_duration: int = 30,
    frame_interval: int = 1
) -> List:
    """
    Handle multiple media file uploads and return an updated gallery.

    The caller's ``gallery`` is not modified. A new list is returned that holds
    the existing items followed by a (path, filename) item for each image or
    extracted video frame.

    Args:
        file_paths: List of uploaded file paths (None/empty adds nothing)
        gallery: Current gallery items (None is treated as empty)
        max_video_duration: Maximum duration to process for videos (seconds)
        frame_interval: Interval between frames for videos (seconds)

    Returns:
        Updated gallery list
    """
    # Copy instead of appending in place so the caller's list is never mutated.
    updated_gallery = [] if gallery is None else list(gallery)
    if not file_paths:
        return updated_gallery

    processed_paths = process_multiple_media_uploads(file_paths, max_video_duration, frame_interval)
    # create_gallery_item returns None for missing/unsupported paths; skip those.
    updated_gallery.extend(
        item for item in map(create_gallery_item, processed_paths) if item
    )
    return updated_gallery
def get_supported_formats() -> dict:
    """
    Get a dictionary of supported file extensions.

    Returns:
        Dict with keys 'images', 'videos' and 'all' ('all' lists the image
        extensions followed by the video extensions). Every value is a fresh
        list, so callers may modify it without altering the module-level
        SUPPORTED_IMAGE_FORMATS / SUPPORTED_VIDEO_FORMATS used for detection.
    """
    # Copy so the module's detection lists cannot be mutated through the result.
    images = list(SUPPORTED_IMAGE_FORMATS)
    videos = list(SUPPORTED_VIDEO_FORMATS)
    return {
        'images': images,
        'videos': videos,
        'all': images + videos,
    }
def validate_media_files(file_paths: List[str]) -> Tuple[List[str], List[str]]:
    """
    Partition paths by whether their extension is a supported image/video type.

    The input is walked exactly once and input order is kept within each
    output list. Only the extension is checked, not whether the file exists.

    Args:
        file_paths: List of file paths to validate

    Returns:
        Tuple of (valid_files, invalid_files)
    """
    valid: List[str] = []
    invalid: List[str] = []
    for candidate in file_paths:
        bucket = valid if is_supported_media(candidate) else invalid
        bucket.append(candidate)
    return valid, invalid
# Public API: the names exported by a star-import of this module.
# NOTE(review): SUPPORTED_VIDEO_FORMATS is imported above but not re-exported
# here — confirm that is intended.
__all__ = [
    'get_media_type',
    'is_supported_media',
    'create_gallery_item',
    'process_single_media_upload',
    'process_multiple_media_uploads',
    'handle_single_media_upload',
    'handle_multiple_media_uploads',
    'get_supported_formats',
    'validate_media_files',
    'SUPPORTED_IMAGE_FORMATS',
]