Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import hashlib | |
| import shutil | |
| from pathlib import Path | |
| import faiss | |
| import torch | |
| from PIL import Image | |
| from torch import nn | |
| from .prompt.fetch.content_fetch import fetch_links_to_json | |
| from .prompt.fetch.satellite_fetch import fetch_satellite_image | |
| from .prompt.preprocess.keyframe_extract import extract_and_save_keyframes | |
| from .prompt.preprocess.video_transcribe import transcribe_video_directory | |
| from .prompt.search.image_search import image_search_directory | |
| from .prompt.search.index_search import save_results_to_json, search_index_directory | |
| from .prompt.search.text_search import text_search_image, text_search_link | |
# Reuse uvicorn's error logger so messages surface in the server's log stream.
logger = logging.getLogger("uvicorn.error")
class DataProcessor:
    """Turns raw input media (images/videos) into prompt artifacts:
    keyframes, transcripts, web-search results, and FAISS geolocation
    candidates, with an on-disk cache to skip repeated work."""

    def __init__(
        self,
        model: nn.Module,
        input_dir: Path,
        prompt_dir: Path,
        cache_dir: Path,
        image_dir: Path,
        audio_dir: Path,
        index_path: Path,
        database_csv_path: Path,
        device: torch.device,
    ):
        """
        Args:
            model: Embedding model used for FAISS index search.
            input_dir: Directory holding the raw input media.
            prompt_dir: Directory where prompt artifacts are written.
            cache_dir: Directory used to cache processed input/prompt data.
            image_dir: Directory for collected images/keyframes.
            audio_dir: Directory for video transcripts.
            index_path: Path to the serialized FAISS index.
            database_csv_path: CSV mapping index entries to GPS data.
            device: Torch device the model runs on.

        Raises:
            RuntimeError: if the FAISS index cannot be loaded.
        """
        self.input_dir = input_dir
        self.prompt_dir = prompt_dir
        self.cache_dir = cache_dir
        self.image_dir = image_dir
        self.audio_dir = audio_dir
        self.model = model
        self.device = device
        self.database_csv_path = database_csv_path
        try:
            self.index = faiss.read_index(str(index_path))
            logger.info(f"✅ Successfully loaded FAISS index from: {index_path}")
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in tracebacks instead of being flattened into the message.
            raise RuntimeError(
                f"Failed to load FAISS index from {index_path}: {e}"
            ) from e
        # File suffixes treated as still images.
        self.image_extension = {
            ".jpg",
            ".jpeg",
            ".png",
            ".bmp",
            ".tiff",
            ".tif",
            ".webp",
        }
        # File suffixes treated as videos.
        self.video_extension = {
            ".mp4",
            ".avi",
            ".mov",
            ".mkv",
        }
| def __extract_keyframes(self): | |
| """ | |
| Extract keyframes from all videos in the input directory. | |
| Put all images and keyframes into the prompt directory. | |
| """ | |
| output_dir = self.image_dir | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Determine starting index based on existing files | |
| current_files = list(output_dir.glob("image_*.*")) | |
| idx = len(current_files) | |
| # Process images | |
| for file_name in os.listdir(self.input_dir): | |
| file_path = os.path.join(self.input_dir, file_name) | |
| if os.path.isfile(file_path) and file_name.lower().endswith( | |
| tuple(self.image_extension) | |
| ): | |
| out_path = output_dir / f"image_{idx:03d}.jpg" | |
| Image.open(file_path).convert("RGB").save(out_path) | |
| idx += 1 | |
| # Process videos | |
| for file_name in os.listdir(self.input_dir): | |
| file_path = os.path.join(self.input_dir, file_name) | |
| if os.path.isfile(file_path) and file_name.lower().endswith( | |
| tuple(self.video_extension) | |
| ): | |
| if idx is None: | |
| idx = 0 | |
| idx = extract_and_save_keyframes( | |
| video_path=file_path, output_dir=str(output_dir), start_index=idx | |
| ) | |
| logger.info(f"✅ Extracted keyframes and images to: {output_dir}") | |
| def __transcribe_videos(self): | |
| """ | |
| Transcribe all videos in the input directory. | |
| Save transcripts into the prompt directory. | |
| """ | |
| audio_dir = self.audio_dir | |
| os.makedirs(audio_dir, exist_ok=True) | |
| if audio_dir.is_dir() and any(audio_dir.iterdir()): | |
| logger.info(f"🔄 Found existing transcripts in directory: {audio_dir}") | |
| return | |
| transcribe_video_directory( | |
| video_dir=str(self.input_dir), | |
| output_dir=str(audio_dir), | |
| model_name="base", # Use the base Whisper model for transcription | |
| ) | |
| logger.info(f"✅ Successfully transcribed videos to: {audio_dir}") | |
| def __image_search(self): | |
| """ | |
| Perform image search on all images in the input directory. | |
| Save search results into the prompt directory. | |
| """ | |
| image_dir = self.image_dir | |
| if os.environ["IMGBB_API_KEY"] is None: | |
| raise ValueError( | |
| "IMGBB_API_KEY environment variable is not set or is None." | |
| ) | |
| if os.environ["SCRAPINGDOG_API_KEY"] is None: | |
| raise ValueError( | |
| "SCRAPINGDOG_API_KEY environment variable is not set or is None." | |
| ) | |
| image_search_directory( | |
| directory=str(image_dir), | |
| output_dir=str(self.prompt_dir), | |
| filename="metadata.json", | |
| imgbb_key=os.environ["IMGBB_API_KEY"], | |
| scrapingdog_key=os.environ["SCRAPINGDOG_API_KEY"], | |
| max_workers=4, | |
| target_links=20, | |
| ) | |
| logger.info(f"✅ Successfully performed image search on: {image_dir}") | |
| def __text_search(self): | |
| """ | |
| Perform text search with metadata to get related links. | |
| """ | |
| query = "" | |
| metadata_file = self.prompt_dir / "metadata.json" | |
| if not metadata_file.exists(): | |
| query = "" | |
| else: | |
| with open(metadata_file, "r") as f: | |
| metadata = json.load(f) | |
| description = metadata.get("description", "") | |
| location = metadata.get("location", "") | |
| query = f"{description} in {location}".strip() | |
| text_search_link( | |
| query=query, | |
| output_dir=str(self.prompt_dir), | |
| filename="text_search.json", | |
| num_results=10, | |
| api_key=os.environ["GOOGLE_CLOUD_API_KEY"], | |
| cx=os.environ["GOOGLE_CSE_CX"], | |
| ) | |
| async def __fetch_related_link_content( | |
| self, image_prediction: bool = True, text_prediction: bool = True | |
| ): | |
| """ | |
| Fetch related link content for all images and text in the prompt directory. | |
| """ | |
| async def fetch_and_save_links(links, output_filename): | |
| if links: | |
| await fetch_links_to_json( | |
| links=list(links), | |
| output_path=str(self.prompt_dir / output_filename), | |
| max_content_length=5000, | |
| ) | |
| logger.info( | |
| f"Fetched content for {len(links)} links into {output_filename}" | |
| ) | |
| # Image links | |
| image_links = set() | |
| image_search_file = self.prompt_dir / "metadata.json" | |
| if image_prediction: | |
| if not image_search_file.exists(): | |
| self.__image_search() | |
| with open(image_search_file, "r") as f: | |
| image_search_data = json.load(f) | |
| image_links.update(image_search_data.get("all_links", [])) | |
| logger.info(f"Found {len(image_links)} image links to fetch content from.") | |
| await fetch_and_save_links(image_links, "image_search_content.json") | |
| # Text links | |
| text_links = set() | |
| text_search_file = self.prompt_dir / "text_search.json" | |
| if text_prediction: | |
| if not text_search_file.exists(): | |
| self.__text_search() | |
| with open(text_search_file, "r") as f: | |
| text_search_data = json.load(f) | |
| text_links.update(filter(None, text_search_data.get("links", []))) | |
| logger.info(f"Found {len(text_links)} text links to fetch content from.") | |
| await fetch_and_save_links(text_links, "text_search_content.json") | |
| if not image_links and not text_links: | |
| logger.info("No links found in image or text search results.") | |
| def __index_search(self): | |
| """ | |
| Perform FAISS index search on all images in the prompt directory. | |
| Save search results into the report directory. | |
| """ | |
| if not self.index: | |
| raise RuntimeError( | |
| "FAISS index is not loaded. Cannot perform index search." | |
| ) | |
| output_path = self.prompt_dir / "index_search.json" | |
| if output_path.exists(): | |
| logger.info( | |
| f"Index search results already exist at {output_path}, skipping search." | |
| ) | |
| return | |
| if not os.path.exists(self.database_csv_path): | |
| raise FileNotFoundError( | |
| f"Database CSV file not found: {self.database_csv_path}" | |
| ) | |
| candidates_gps, reverse_gps = search_index_directory( | |
| model=self.model, | |
| device=self.device, | |
| index=self.index, | |
| image_dir=str(self.image_dir), | |
| database_csv_path=str(self.database_csv_path), | |
| top_k=20, | |
| max_elements=20, | |
| ) | |
| save_results_to_json(candidates_gps, reverse_gps, str(output_path)) | |
| logger.info( | |
| f"✅ Successfully performed index search. Results saved to: {output_path}" | |
| ) | |
| async def __fetch_satellite_image_async( | |
| self, | |
| latitude: float, | |
| longitude: float, | |
| zoom: int, | |
| output_path: Path, | |
| ) -> None: | |
| """ | |
| Asynchronously fetches a satellite image without blocking the event loop. | |
| Runs the synchronous `fetch_satellite_image` function in a background thread. | |
| Args: | |
| latitude (float): Latitude of the location. | |
| longitude (float): Longitude of the location. | |
| zoom (int): Zoom level of the satellite image. | |
| output_path (Path): Path to save the image file. | |
| """ | |
| await asyncio.to_thread( | |
| fetch_satellite_image, | |
| latitude, | |
| longitude, | |
| zoom, | |
| str(output_path), | |
| ) | |
| async def __search_images_async( | |
| self, | |
| location: str, | |
| num_images: int, | |
| api_key: str | None, | |
| cse_cx: str | None, | |
| output_dir: Path, | |
| image_id_offset: int, | |
| ) -> list[str]: | |
| """ | |
| Asynchronously searches for images based on a text location query. | |
| Args: | |
| location (str): Text location to search. | |
| num_images (int): Number of images to fetch. | |
| api_key (str): Google Cloud API key. | |
| cse_cx (str): Google Custom Search Engine ID. | |
| output_dir (Path): Directory where images will be saved. | |
| image_id_offset (int): Offset for image filenames. | |
| Returns: | |
| Any: The result of `text_search_image`, if it returns a value. | |
| """ | |
| return await asyncio.to_thread( | |
| text_search_image, | |
| location, | |
| num_images, | |
| api_key, | |
| cse_cx, | |
| str(output_dir), | |
| image_id_offset, | |
| ) | |
| def __compute_sha256(self, filepath: Path) -> str: | |
| """ | |
| Compute the SHA-256 hash of a file. | |
| """ | |
| if not filepath.is_file(): | |
| raise ValueError(f"File does not exist: {filepath}") | |
| sha256 = hashlib.sha256() | |
| with open(filepath, "rb") as f: | |
| for chunk in iter(lambda: f.read(4096), b""): | |
| sha256.update(chunk) | |
| return sha256.hexdigest() | |
| def __compare_directories(self, dir1: Path, dir2: Path) -> bool: | |
| """ | |
| Compare two directories to check if they contain the same files with identical content. | |
| Args: | |
| dir1 (Path): First directory to compare. | |
| dir2 (Path): Second directory to compare. | |
| Returns: | |
| bool: True if both directories contain the same files with identical content, False otherwise. | |
| """ | |
| if not dir1.is_dir() or not dir2.is_dir(): | |
| return False | |
| files1 = sorted(p for p in dir1.iterdir() if p.is_file()) | |
| files2 = sorted(p for p in dir2.iterdir() if p.is_file()) | |
| # Check if filenames match exactly | |
| names1 = {p.name for p in files1} | |
| names2 = {p.name for p in files2} | |
| if names1 != names2: | |
| return False | |
| # Compare each matching file | |
| for filename in names1: | |
| path1 = dir1 / filename | |
| path2 = dir2 / filename | |
| # Skip directories | |
| if not path1.is_file() or not path2.is_file(): | |
| continue | |
| hash1 = self.__compute_sha256(path1) | |
| hash2 = self.__compute_sha256(path2) | |
| if hash1 != hash2: | |
| return False # Found mismatch | |
| return True # All matching files are identical | |
| def __copy_directory(self, src: Path, dest: Path): | |
| """ | |
| Recursively copy all files from src to dest. | |
| """ | |
| if not src.is_dir(): | |
| raise ValueError(f"Source path is not a directory: {src}") | |
| # Delete everything in dest first | |
| if dest.exists(): | |
| for item in dest.iterdir(): | |
| if item.is_file() or item.is_symlink(): | |
| item.unlink() | |
| elif item.is_dir(): | |
| shutil.rmtree(item) | |
| # Ensure dest exists | |
| dest.mkdir(parents=True, exist_ok=True) | |
| for item in src.iterdir(): | |
| if item.is_dir(): | |
| self.__copy_directory(item, dest / item.name) | |
| else: | |
| dest_file = dest / item.name | |
| if not dest_file.exists() or not self.__compare_directories( | |
| item, dest_file | |
| ): | |
| shutil.copy2(item, dest_file) | |
| async def preprocess_input_data( | |
| self, | |
| image_prediction: bool = True, | |
| text_prediction: bool = True, | |
| ): | |
| """ | |
| Preprocess all input data: | |
| - Extract keyframes from videos. | |
| - Transcribe videos. | |
| - Fetch related link content from images. | |
| Save images and extracted keyframes into the output directory | |
| """ | |
| os.makedirs(self.prompt_dir, exist_ok=True) | |
| os.makedirs(self.cache_dir, exist_ok=True) | |
| cache_dir_input = self.cache_dir / "input_data" | |
| cache_dir_prompt = self.cache_dir / "prompt_data" | |
| if self.__compare_directories(self.input_dir, cache_dir_input): | |
| logger.info("Input data already processed, skipping...") | |
| self.__copy_directory(cache_dir_prompt, self.prompt_dir) | |
| return | |
| else: | |
| logger.info("Processing input data...") | |
| metadata_dest = self.prompt_dir / "metadata.json" | |
| if not metadata_dest.exists(): | |
| for file in os.listdir(self.input_dir): | |
| if file.endswith(".json"): | |
| file_path = os.path.join(self.input_dir, file) | |
| with open(file_path, "r") as src_file: | |
| with open(metadata_dest, "w") as dest_file: | |
| dest_file.write(src_file.read()) | |
| break | |
| self.__extract_keyframes() | |
| self.__transcribe_videos() | |
| await self.__fetch_related_link_content( | |
| image_prediction=image_prediction, text_prediction=text_prediction | |
| ) | |
| self.__index_search() | |
| logger.info("✅ Preprocessing completed") | |
| logger.info(f"Saving processed data to cache directory: {self.cache_dir}") | |
| self.__copy_directory(self.input_dir, cache_dir_input) | |
| self.__copy_directory(self.prompt_dir, cache_dir_prompt) | |
| async def prepare_location_images( | |
| self, | |
| prediction: dict, | |
| image_prediction: bool = True, | |
| text_prediction: bool = True, | |
| ) -> int: | |
| """ | |
| Prepare verification data from the prediction with parallel fetching. | |
| Args: | |
| prediction (dict): Prediction dictionary with latitude, longitude, location, reason, and metadata | |
| image_prediction (bool): Whether to include original images in verification | |
| text_prediction (bool): Whether to include text-based verification | |
| Returns: | |
| int: Satellite image ID for reference in prompts | |
| """ | |
| image_dir = self.image_dir | |
| satellite_image_id = len(list(self.image_dir.glob("image_*.*"))) | |
| # Execute both operations in parallel | |
| logger.info("🔄 Fetching satellite image and location images in parallel...") | |
| # Ensure required API keys are present | |
| if not os.environ.get("GOOGLE_CLOUD_API_KEY"): | |
| raise ValueError( | |
| "GOOGLE_CLOUD_API_KEY environment variable is not set or is None." | |
| ) | |
| if not os.environ.get("GOOGLE_CSE_CX"): | |
| raise ValueError( | |
| "GOOGLE_CSE_CX environment variable is not set or is None." | |
| ) | |
| await asyncio.gather( | |
| self.__fetch_satellite_image_async( | |
| prediction["latitude"], | |
| prediction["longitude"], | |
| zoom=200, | |
| output_path=image_dir / f"image_{satellite_image_id:03d}.jpg", | |
| ), | |
| self.__search_images_async( | |
| location=prediction["location"], | |
| num_images=5, | |
| api_key=os.environ["GOOGLE_CLOUD_API_KEY"], | |
| cse_cx=os.environ["GOOGLE_CSE_CX"], | |
| output_dir=image_dir, | |
| image_id_offset=satellite_image_id + 1, | |
| ), | |
| ) | |
| logger.info("✅ Verification data preparation completed") | |
| return satellite_image_id | |