Hanzo03 commited on
Commit
ccdd4a4
·
1 Parent(s): e977014

initial commit

Browse files
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ GEMINI_API_KEY = "your_gemini_api_key_here"
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ # Python-generated files
37
+ __pycache__/
38
+ *.py[oc]
39
+ build/
40
+ dist/
41
+ wheels/
42
+ *.egg-info
43
+
44
+ # Virtual environments
45
+ .venv
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ======================== app.py ========================
2
+
3
+ import gradio as gr
4
+ import os
5
+ import json
6
+ import shutil # Used for file operations
7
+
8
+ # Import the core functions from your existing files
9
+ from modules.video_analyzer import analyze_video_for_ppe
10
+ from modules.rag_indexer import index_analysis_data
11
+ from modules.rag_query import run_query
12
+
13
+ # Configuration (Keep consistent with rag_indexer.py and rag_query.py)
14
+ VIDEO_FILENAME = "uploaded_video.mp4" # Temp name for the uploaded file
15
+ RAW_ANALYSIS_FILE = 'raw_analysis.json'
16
+ DB_PATH = "./chroma_db"
17
+ COLLECTION_NAME = 'video_analysis_data' # Use the general collection name
18
+
19
def pipeline_fn(video_file, user_query):
    """
    Run the full upload -> analyze -> index -> query pipeline for Gradio.

    Args:
        video_file: The uploaded video. With ``gr.File(type="filepath")``
            Gradio passes a plain string path (older Gradio versions passed a
            tempfile object exposing ``.name``); both forms are accepted.
        user_query: The text question to ask about the video content.

    Returns:
        The RAG answer string, or a human-readable error message.
    """
    if video_file is None:
        return "Error: Please upload a video file first."
    if not user_query:
        return "Error: Please enter a question to query the video analysis."

    # 1. Handle file upload and naming.
    # Copy the upload to a fixed name in the working directory so the
    # analyzer can find it consistently.
    try:
        # type="filepath" yields a str; fall back to .name for file objects.
        source_path = video_file if isinstance(video_file, str) else video_file.name
        temp_video_path = os.path.join(os.getcwd(), VIDEO_FILENAME)
        shutil.copy(source_path, temp_video_path)
        print(f"Copied uploaded file to: {temp_video_path}")
    except Exception as e:
        return f"File handling error: {e}"

    try:
        # 2. Analyze the video (2 frames/sec keeps the demo reasonably fast).
        print("\n--- STAGE 1: Analyzing Video ---")
        analysis_results = analyze_video_for_ppe(
            video_path=temp_video_path,
            frames_per_sec=2
        )

        # Persist the raw analysis so the indexer can pick it up.
        with open(RAW_ANALYSIS_FILE, 'w') as f:
            json.dump(analysis_results, f, indent=4)

        # 3. Index the analysis data into ChromaDB.
        print("\n--- STAGE 2: Indexing Analysis Data ---")
        index_analysis_data(json_file=RAW_ANALYSIS_FILE, collection_name=COLLECTION_NAME)

        # 4. Execute the RAG query against the freshly indexed data.
        print("\n--- STAGE 3: Executing RAG Query ---")
        rag_answer = run_query(user_query)
    finally:
        # 5. Cleanup: remove the temp files even if a pipeline stage failed,
        # so repeated demo runs don't accumulate stale artifacts.
        for path in (temp_video_path, RAW_ANALYSIS_FILE):
            if os.path.exists(path):
                os.remove(path)

    return rag_answer
72
+
73
# --- Gradio interface definition ---

# Input widgets: the video upload and the user's free-text question.
video_input = gr.File(
    label="Upload Video File (.mp4, .mov, etc.)",
    file_types=["video"],  # only accept video uploads
    type="filepath"
)
query_input = gr.Textbox(
    label="Ask a Question about the Video Content",
    placeholder="e.g., What are people doing in the video?",
    lines=2
)

# Output widget: a read-only box holding the generated RAG answer.
output_textbox = gr.Textbox(
    label="RAG Analysis Result",
    lines=10,
    interactive=False
)

# Wire the pipeline function to the UI components.
demo = gr.Interface(
    fn=pipeline_fn,
    inputs=[video_input, query_input],
    outputs=output_textbox,
    title="🚀 Video Content RAG Pipeline",
    description="Upload a video, and ask a question. The pipeline runs object detection, indexes the data, and uses Gemini to answer your question based on the analysis.",
)

if __name__ == "__main__":
    print("Launching Gradio App...")
    # Serves locally at http://127.0.0.1:7860/ by default.
    demo.launch()
main.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import json
4
+ from modules.video_analyzer import analyze_video_for_ppe
5
+ from modules.rag_indexer import index_analysis_data
6
+ from modules.rag_query import run_query
7
+
8
+ # --- Configuration ---
9
+ RAW_ANALYSIS_FILE = 'raw_analysis.json'
10
+ MODEL_PATH = 'yolov8n.pt' # Default YOLOv8 model for general objects
11
+
12
def main():
    """
    Execute the full Video Analysis -> Indexing -> Querying RAG pipeline.

    Parses --video-path, --query and --frames-per-sec from the command line,
    validates prerequisites, then runs the three pipeline stages, printing
    progress and the final RAG answer. Returns None in all cases.
    """
    parser = argparse.ArgumentParser(
        description="Run the full PPE Compliance RAG pipeline.",
        formatter_class=argparse.RawTextHelpFormatter
    )

    parser.add_argument(
        '--video-path',
        type=str,
        required=True,
        help="Path to the video file to analyze (e.g., 'construction.mp4')."
    )

    parser.add_argument(
        '--query',
        type=str,
        required=True,
        help="The natural language query to ask the RAG system (e.g., 'Summarize safety violations')."
    )

    # '--frames-per-sec' matches the dashed style of '--video-path';
    # '--frames_per_sec' is kept as an alias for backward compatibility.
    parser.add_argument(
        '--frames-per-sec', '--frames_per_sec',
        dest='frames_per_sec',
        type=float,
        default=0.5,
        help="Number of frames to sample per second for analysis (Default: 0.5)."
    )

    args = parser.parse_args()

    video_path = args.video_path
    user_query = args.query
    frames_per_sec = args.frames_per_sec

    # 1. Check for prerequisites.
    if not os.path.exists(video_path):
        print(f"Error: Video file not found at '{video_path}'.")
        return

    if not os.path.exists(MODEL_PATH):
        print(f"Warning: YOLO model '{MODEL_PATH}' not found. You might need to download it or change MODEL_PATH.")
        print("Proceeding, but analysis will likely fail if the model is missing.")
        # Let the analyzer surface the actual model-load error.

    print("="*60)
    print("🚀 Starting PPE Compliance RAG Pipeline")
    print("="*60)

    # --- STAGE 1: Video Analysis ---
    print(f"\n--- STAGE 1: Analyzing Video '{os.path.basename(video_path)}' ---")
    print(f"Sampling Rate: {frames_per_sec} frames/sec")

    analysis_results = analyze_video_for_ppe(
        video_path=video_path,
        model_path=MODEL_PATH,
        frames_per_sec=frames_per_sec
    )

    if not analysis_results:
        print("\nAnalysis failed or returned no results. Aborting pipeline.")
        return

    # Save raw results for the indexer to pick up.
    with open(RAW_ANALYSIS_FILE, 'w') as f:
        json.dump(analysis_results, f, indent=4)
    print(f"Raw analysis saved to '{RAW_ANALYSIS_FILE}'. {len(analysis_results)} records created.")

    # --- STAGE 2: Data Indexing (RAG Indexer) ---
    print("\n--- STAGE 2: Indexing Analysis Data into ChromaDB ---")
    # index_analysis_data expects the file named RAW_ANALYSIS_FILE.
    index_analysis_data(json_file=RAW_ANALYSIS_FILE)

    # --- STAGE 3: RAG Query ---
    print("\n--- STAGE 3: Executing RAG Query ---")
    print(f"User Question: {user_query}")

    try:
        # Run the RAG query pipeline.
        rag_answer = run_query(user_query)

        print("\n" + "="*60)
        print("✅ RAG Pipeline Complete")
        print("="*60)
        print("\n--- RAG ANSWER ---")
        print(rag_answer)

    except Exception as e:
        print(f"\nError during RAG Query execution: {e}")
        # The project's .env defines GEMINI_API_KEY (not GOOGLE_API_KEY),
        # so point the user at the variable that is actually used.
        print("Please ensure your environment variables (like GEMINI_API_KEY) are set and dependencies are installed.")


if __name__ == '__main__':
    main()
modules/__pycache__/rag_indexer.cpython-312.pyc ADDED
Binary file (3.28 kB). View file
 
modules/__pycache__/rag_query.cpython-312.pyc ADDED
Binary file (3.44 kB). View file
 
modules/__pycache__/video_analyzer.cpython-312.pyc ADDED
Binary file (3.32 kB). View file
 
modules/rag_indexer.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ import json
3
+
4
+ COLLECTION_NAME = 'video_analysis_data'
5
+ # ... (rest of imports/constants)
6
+
7
def generate_text_summary(record):
    """
    Render one structured detection record as a natural-language sentence.

    Tallies the frame's detections by label and reports how many instances of
    each label were seen, prefixed with the video id and timestamp so the
    resulting document is self-contained when retrieved later.
    """
    prefix = (
        f"Analysis of video '{record['video_id']}' "
        f"at {record['timestamp_sec']} seconds:"
    )
    detections = record['detections']

    if not detections:
        return f"{prefix} No objects were detected in this frame."

    # Count occurrences of each label, preserving first-seen order.
    counts = {}
    for detection in detections:
        label = detection['label']
        counts[label] = counts.get(label, 0) + 1

    # Format: N instances of 'label', M instances of 'other_label', ...
    described = ", ".join(
        f"{count} instances of '{label}'" for label, count in counts.items()
    )
    return f"{prefix} Detected objects include: {described}."
37
+
38
+
39
def index_analysis_data(json_file='raw_analysis.json', collection_name='video_analysis_data', db_path='./chroma_db'):
    """
    Load raw analysis records, generate text documents, and index them in ChromaDB.

    Args:
        json_file: Path to the JSON file produced by the video analyzer.
        collection_name: ChromaDB collection to create or update.
        db_path: Directory for the persistent ChromaDB store. New parameter;
            defaults to the previously hard-coded "./chroma_db" location.

    Returns:
        None. Prints progress; returns early if the JSON file is missing.
    """
    try:
        with open(json_file, 'r') as f:
            raw_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: {json_file} not found. Run 'video_analyzer.py' first.")
        return

    # Initialize ChromaDB client (stores data locally under db_path).
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_or_create_collection(name=collection_name)

    documents = []
    metadatas = []
    ids = []

    print(f"Indexing {len(raw_data)} analysis records...")

    for record in raw_data:
        doc_text = generate_text_summary(record)
        if doc_text:
            documents.append(doc_text)
            # Metadata is crucial for filtering by video/time at query time.
            metadatas.append({
                'video_id': record['video_id'],
                'timestamp_sec': record['timestamp_sec'],
                'frame_id': record['frame_id']
            })
            # Key ids on video + frame instead of positional "doc_{i}":
            # positional ids collide when another analysis run is indexed
            # into the same persistent collection.
            ids.append(f"{record['video_id']}_frame_{record['frame_id']}")

    # ChromaDB automatically handles embedding and storage. upsert (rather
    # than add) makes re-indexing the same video idempotent instead of
    # failing on duplicate ids.
    if documents:
        collection.upsert(
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )
        print(f"Successfully indexed {len(documents)} documents into ChromaDB collection '{collection_name}'.")
    else:
        print("No valid documents generated for indexing.")


if __name__ == '__main__':
    index_analysis_data()
modules/rag_query.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ from langchain_google_genai import ChatGoogleGenerativeAI
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+ from langchain_core.runnables import RunnablePassthrough
5
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
6
+ from langchain_chroma import Chroma
7
+ import os
8
+
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv() # Load environment variables from .env file
12
+
13
# 1. RAG component configuration
GEMINI_MODEL = "gemini-2.5-flash"
COLLECTION_NAME = 'video_analysis_data'  # must match the indexer's collection
DB_PATH = "./chroma_db"

def run_query(user_query):
    """
    Execute the RAG pipeline: retrieve relevant context and generate an answer.

    Args:
        user_query: The natural-language question about the indexed analysis.

    Returns:
        The generated answer text, or a fallback message if no index exists.
    """
    if not os.path.exists(DB_PATH):
        print(f"Error: Database path {DB_PATH} not found. Run 'rag_indexer.py' first.")
        return "Analysis data is not yet indexed. Please index the data first."

    # ChromaDB setup
    client = chromadb.PersistentClient(path=DB_PATH)

    embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectorstore = Chroma(
        client=client,
        collection_name=COLLECTION_NAME,
        embedding_function=embedding_function
    )

    # 2. Retrieval (R): top 5 most relevant analysis documents.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

    # 3. Generation (G): general-purpose scene-description system prompt.
    template = """
    You are an expert Video Content Analyst. Your task is to describe the scene, objects, and potential context detected in a video based on the provided analysis logs (Context).
    Use the Context to answer the user's Question.

    IMPORTANT: Since the analysis only provides object labels and locations (Object Detection) and not specific actions (Activity Recognition), you must infer and describe the *scene* or *context* based on the objects present (e.g., "The presence of people and construction materials suggests activity on a construction site.").

    Always include the video timestamp(s) where the relevant context was found.

    Context:
    {context}

    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)

    # The project's .env defines GEMINI_API_KEY, but ChatGoogleGenerativeAI
    # looks for GOOGLE_API_KEY by default — accept either so the key that is
    # actually configured gets used.
    api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
    llm_kwargs = {"model": GEMINI_MODEL}
    if api_key:
        llm_kwargs["google_api_key"] = api_key
    llm = ChatGoogleGenerativeAI(**llm_kwargs)

    # The RAG chain: retrieved docs become {context}, the raw query {question}.
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | llm
    )

    # 4. Execute the chain.
    print(f"Executing RAG query for: '{user_query}'...")
    response = rag_chain.invoke(user_query)

    return response.content
70
+
71
# Example usage:
if __name__ == '__main__':
    # Requires indexed data — run rag_indexer.py first.
    demo_questions = [
        "What kind of objects were frequently detected in the video?",
        "What activity was detected around the 15-second mark in the video?",
    ]

    for number, question in enumerate(demo_questions, start=1):
        if number > 1:
            print("\n" + "="*50 + "\n")
        answer = run_query(question)
        print(f"\n--- QUERY {number} ---")
        print(f"Question: {question}")
        print(f"Answer:\n{answer}")
modules/video_analyzer.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from ultralytics import YOLO
3
+ import json
4
+ import os
5
+
6
def analyze_video_for_ppe(video_path, model_path='yolov8n.pt', frames_per_sec=1.0):
    """
    Analyze a video with a YOLOv8 model, sampling frames at a fixed rate.

    Args:
        video_path: Path to the input video file.
        model_path: YOLOv8 weights to load (replace with a fine-tuned PPE
            model for real compliance detection; the default detects general
            objects such as 'person').
        frames_per_sec: How many frames to analyze per second of video.

    Returns:
        A list of per-frame records of the form
        {'video_id', 'frame_id', 'timestamp_sec', 'detections': [...]},
        or an empty list on any setup error.
    """
    # 1. Load the YOLOv8 model.
    try:
        model = YOLO(model_path)
    except Exception as e:
        print(f"Error loading model: {e}. Ensure you have a valid YOLOv8 model path.")
        return []

    # 2. Open the video file.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file {video_path}")
        return []

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps <= 0 or frames_per_sec <= 0:
        # Some containers report FPS 0; the naive fps/frames_per_sec division
        # below would then raise ZeroDivisionError.
        print(f"Error: Invalid video FPS ({fps}) or sampling rate ({frames_per_sec}).")
        cap.release()
        return []

    # Sample every Nth frame. Clamp to at least 1 so a sampling rate above
    # the video's FPS analyzes every frame instead of crashing on modulo-0.
    frame_interval = max(1, int(fps / frames_per_sec))
    frame_count = 0
    analysis_results = []

    print(f"Video FPS: {fps}, Analyzing every {frame_interval} frames...")

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Only analyze sampled frames.
            if frame_count % frame_interval == 0:
                timestamp_sec = frame_count / fps

                # 3. Run detection on the frame.
                results = model(frame, verbose=False)

                # 4. Flatten results; each box is [x1, y1, x2, y2, conf, cls].
                detections = []
                for r in results:
                    for box in r.boxes.data.tolist():
                        x1, y1, x2, y2, conf, cls = box
                        detections.append({
                            'label': model.names[int(cls)],
                            'confidence': round(conf, 2),
                            'bbox': [int(x1), int(y1), int(x2), int(y2)]  # Bounding box
                        })

                analysis_results.append({
                    'video_id': os.path.basename(video_path),
                    'frame_id': frame_count,
                    'timestamp_sec': round(timestamp_sec, 2),
                    'detections': detections
                })

            frame_count += 1
    finally:
        # 5. Always release the capture, even if detection raised mid-loop.
        cap.release()

    print(f"Analysis complete. Total frames analyzed: {len(analysis_results)}")
    return analysis_results
77
+
78
+ # Example Usage:
79
+ if __name__ == '__main__':
80
+ # NOTE: You'll need a sample video file in the same directory (e.g., 'construction.mp4')
81
+ # and a trained PPE model file. For a quick test, you can use the default 'yolov8n.pt'
82
+ # which detects general objects (like 'person') until you fine-tune a PPE model.
83
+ if not os.path.exists('construction.mp4'):
84
+ print("Please place a video file named 'construction.mp4' in the current directory.")
85
+ else:
86
+ results = analyze_video_for_ppe('construction.mp4', frames_per_sec=0.5)
87
+
88
+ # Save raw results (optional, but good for debugging)
89
+ with open('raw_analysis.json', 'w') as f:
90
+ json.dump(results, f, indent=4)
91
+
92
+ print(f"Raw analysis saved to raw_analysis.json. {len(results)} records created.")
pyproject.toml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "rag"
3
+ version = "0.1.0"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "chromadb>=1.3.6",
9
+ "faiss-cpu>=1.13.1",
10
+ "google-genai>=1.55.0",
11
+ "gradio>=6.1.0",
12
+ "langchain>=1.1.3",
13
+ "langchain-chroma>=1.0.0",
14
+ "langchain-community>=0.4.1",
15
+ "langchain-core>=1.1.3",
16
+ "langchain-google-genai>=4.0.0",
17
+ "langchain-huggingface>=1.1.0",
18
+ "numpy>=2.3.5",
19
+ "opencv-python>=4.11.0.86",
20
+ "python-dotenv>=1.2.1",
21
+ "sentence-transformers>=5.2.0",
22
+ "streamlit>=1.52.1",
23
+ "ultralytics>=8.3.236",
24
+ ]
req.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python
2
+ ultralytics
3
+ numpy
4
+ chromadb
5
+ sentence-transformers
6
+ google-genai
7
+ langchain
8
+ langchain-google-genai
9
+ langchain-core
10
+ langchain-community
11
+ langchain-huggingface
12
+ langchain-chroma
uv.lock ADDED
The diff for this file is too large to render. See raw diff