RAG / main.py
Hanzo03's picture
initial commit
ccdd4a4
import argparse
import json
import os
import sys

from modules.video_analyzer import analyze_video_for_ppe
from modules.rag_indexer import index_analysis_data
from modules.rag_query import run_query
# --- Configuration ---
# Both paths are relative, so they resolve against the current working
# directory of the process, not the directory containing this script.

# Where the raw analysis records from stage 1 are written as JSON.
RAW_ANALYSIS_FILE = "raw_analysis.json"

# YOLOv8 weights handed to the video analyzer (default general-object model).
MODEL_PATH = "yolov8n.pt"
def _positive_float(text):
    """Argparse ``type=`` callback: parse *text* as a finite float greater than zero."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # The chained comparison is False for NaN (all NaN comparisons are False)
    # and for +/-inf, so those are rejected along with zero and negatives.
    if not 0 < value < float("inf"):
        raise argparse.ArgumentTypeError(f"must be a positive number, got {text!r}")
    return value


def _build_parser():
    """Build the command-line parser for the pipeline."""
    parser = argparse.ArgumentParser(
        description="Run the full PPE Compliance RAG pipeline.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        '--video-path',
        type=str,
        required=True,
        help="Path to the video file to analyze (e.g., 'construction.mp4').",
    )
    parser.add_argument(
        '--query',
        type=str,
        required=True,
        help="The natural language query to ask the RAG system (e.g., 'Summarize safety violations').",
    )
    # '--frames-per-sec' matches the hyphenated style of the other flags;
    # '--frames_per_sec' is kept as an alias so existing invocations still work.
    parser.add_argument(
        '--frames-per-sec',
        '--frames_per_sec',
        dest='frames_per_sec',
        type=_positive_float,
        default=0.5,
        help="Number of frames to sample per second for analysis (Default: 0.5).",
    )
    return parser


def main(argv=None):
    """
    Executes the full Video Analysis -> Indexing -> Querying RAG pipeline.

    Stage 1 runs ``analyze_video_for_ppe`` on the video and writes its results
    to ``RAW_ANALYSIS_FILE`` as JSON. Stage 2 indexes that file with
    ``index_analysis_data``. Stage 3 answers the user's question with
    ``run_query`` and prints the answer.

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            ``None`` (the default) parses ``sys.argv[1:]``.

    Returns:
        0 when every stage completed. 1 when the pipeline stopped early: the
        video file is missing, the analysis returned nothing, or the RAG
        query raised. Invalid arguments make argparse exit with status 2.
        Exceptions raised by the analysis or indexing stages propagate.
    """
    args = _build_parser().parse_args(argv)
    video_path = args.video_path
    user_query = args.query
    frames_per_sec = args.frames_per_sec

    # 1. Check for prerequisites. isfile (not exists) so a directory is rejected.
    if not os.path.isfile(video_path):
        print(f"Error: Video file not found at '{video_path}'.", file=sys.stderr)
        return 1
    if not os.path.exists(MODEL_PATH):
        print(
            f"Warning: YOLO model '{MODEL_PATH}' not found. You might need to download it or change MODEL_PATH.",
            file=sys.stderr,
        )
        print("Proceeding, but analysis will likely fail if the model is missing.", file=sys.stderr)
        # We allow it to proceed to let the analyzer handle the error.

    print("=" * 60)
    print("🚀 Starting PPE Compliance RAG Pipeline")
    print("=" * 60)

    # --- STAGE 1: Video Analysis ---
    print(f"\n--- STAGE 1: Analyzing Video '{os.path.basename(video_path)}' ---")
    print(f"Sampling Rate: {frames_per_sec} frames/sec")
    analysis_results = analyze_video_for_ppe(
        video_path=video_path,
        model_path=MODEL_PATH,
        frames_per_sec=frames_per_sec,
    )
    if not analysis_results:
        print("\nAnalysis failed or returned no results. Aborting pipeline.", file=sys.stderr)
        return 1

    # Save raw results so the indexer can read them back from disk.
    with open(RAW_ANALYSIS_FILE, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=4)
    print(f"Raw analysis saved to '{RAW_ANALYSIS_FILE}'. {len(analysis_results)} records created.")

    # --- STAGE 2: Data Indexing (RAG Indexer) ---
    print("\n--- STAGE 2: Indexing Analysis Data into ChromaDB ---")
    # Index the JSON file written above.
    index_analysis_data(json_file=RAW_ANALYSIS_FILE)

    # --- STAGE 3: RAG Query ---
    print("\n--- STAGE 3: Executing RAG Query ---")
    print(f"User Question: {user_query}")
    try:
        rag_answer = run_query(user_query)
    except Exception as e:  # top-level boundary: report the failure and signal it via the exit code
        print(f"\nError during RAG Query execution: {e}", file=sys.stderr)
        print(
            "Please ensure your environment variables (like GOOGLE_API_KEY) are set and dependencies are installed.",
            file=sys.stderr,
        )
        return 1

    print("\n" + "=" * 60)
    print("✅ RAG Pipeline Complete")
    print("=" * 60)
    print("\n--- RAG ANSWER ---")
    print(rag_answer)
    return 0
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status; a None
    # return maps to status 0, an int (e.g. 1 on failure) is used as-is.
    raise SystemExit(main())