#!/usr/bin/env python3 import os import subprocess import sys import traceback import gradio as gr from openai import OpenAI from typing import Dict import cv2 import numpy as np import pandas as pd from sklearn.metrics.pairwise import cosine_similarity from video_classification_frozen.eval import init_opt, init_model from src.datasets.data_manager import init_data from src.utils.distributed import init_distributed import torch import torch.nn.functional as F import tempfile from src.models.attentive_pooler import AttentiveClassifier from video_classification_frozen.utils import make_transforms, FrameAggregation, ClipAggregation import logging import threading from datetime import datetime from typing import Any, Tuple import os from dotenv import load_dotenv import time from langchain.chains.llm import LLMChain from langchain_community.chat_message_histories import ChatMessageHistory from langchain.retrievers import ContextualCompressionRetriever from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline from langchain_community.vectorstores import Neo4jVector from langchain_core.callbacks import BaseCallbackHandler from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableBranch from langchain_openai import ChatOpenAI from langchain_text_splitters import TokenTextSplitter from kg_sys.llm import get_llm from kg_sys.shared.constants import * from kg_sys.shared.common_fn import create_graph_database_connection, load_embedding_model from langchain_core.messages import HumanMessage, AIMessage from langchain_community.chat_message_histories import Neo4jChatMessageHistory from langchain.chains import GraphCypherQAChain import json from kg_sys.shared.constants import CHAT_SYSTEM_TEMPLATE, CHAT_TOKEN_CUT_OFF # Load environment variables load_dotenv() # Constants EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL") 
NEO4J_URI = "neo4j+s://02616835.databases.neo4j.io" USERNAME = "neo4j" PASSWORD = "KBbDcrVF5pQklhEzfy8wXP9a5wIV_IfV7Qdr64SvTkQ" MODEL = os.getenv("VITE_LLM_MODELS") EMBEDDING_FUNCTION, _ = load_embedding_model(EMBEDDING_MODEL) # Helper functions from your code def setup_neo4j_connection(): return create_graph_database_connection(uri=NEO4J_URI, userName=USERNAME, password=PASSWORD, database='neo4j') class SessionChatHistory: history_dict = {} @classmethod def get_chat_history(cls, session_id): """Retrieve or create chat message history for a given session ID.""" if session_id not in cls.history_dict: logging.info(f"Creating new ChatMessageHistory Local for session ID: {session_id}") cls.history_dict[session_id] = ChatMessageHistory() else: logging.info(f"Retrieved existing ChatMessageHistory Local for session ID: {session_id}") return cls.history_dict[session_id] class CustomCallback(BaseCallbackHandler): def __init__(self): self.transformed_question = None def on_llm_end( self, response, **kwargs: Any ) -> None: logging.info("question transformed") self.transformed_question = response.generations[0][0].text.strip() def get_history_by_session_id(session_id): try: return SessionChatHistory.get_chat_history(session_id) except Exception as e: logging.error(f"Failed to get history for session ID '{session_id}': {e}") raise def get_total_tokens(ai_response, llm): try: if isinstance(llm, (ChatOpenAI)): total_tokens = ai_response.response_metadata.get('token_usage', {}).get('total_tokens', 0) else: logging.warning(f"Unrecognized language model: {type(llm)}. 
Returning 0 tokens.") total_tokens = 0 except Exception as e: logging.error(f"Error retrieving total tokens: {e}") total_tokens = 0 return total_tokens def clear_chat_history(graph, session_id, local=False): try: if not local: history = Neo4jChatMessageHistory( graph=graph, session_id=session_id ) else: history = get_history_by_session_id(session_id) history.clear() return { "session_id": session_id, "message": "The chat history has been cleared.", "user": "chatbot" } except Exception as e: logging.error(f"Error clearing chat history for session {session_id}: {e}") return { "session_id": session_id, "message": "Failed to clear chat history.", "user": "chatbot" } def get_sources_and_chunks(sources_used, docs): chunkdetails_list = [] sources_used_set = set(sources_used) seen_ids_and_scores = set() for doc in docs: try: source = doc.metadata.get("source") chunkdetails = doc.metadata.get("chunkdetails", []) if source in sources_used_set: for chunkdetail in chunkdetails: id = chunkdetail.get("id") score = round(chunkdetail.get("score", 0), 4) id_and_score = (id, score) if id_and_score not in seen_ids_and_scores: seen_ids_and_scores.add(id_and_score) chunkdetails_list.append({**chunkdetail, "score": score}) except Exception as e: logging.error(f"Error processing document: {e}") result = { 'sources': sources_used, 'chunkdetails': chunkdetails_list, } return result def get_rag_chain(llm, system_template=CHAT_SYSTEM_TEMPLATE): try: question_answering_prompt = ChatPromptTemplate.from_messages( [ ("system", system_template), MessagesPlaceholder(variable_name="messages"), ( "human", "User question: {input}" ), ] ) question_answering_chain = question_answering_prompt | llm return question_answering_chain except Exception as e: logging.error(f"Error creating RAG chain: {e}") raise def format_documents(documents, model): prompt_token_cutoff = 4 for model_names, value in CHAT_TOKEN_CUT_OFF.items(): if model in model_names: prompt_token_cutoff = value break sorted_documents = 
sorted(documents, key=lambda doc: doc.state.get("query_similarity_score", 0), reverse=True) sorted_documents = sorted_documents[:prompt_token_cutoff] formatted_docs = list() sources = set() entities = dict() global_communities = list() for doc in sorted_documents: try: source = doc.metadata.get('source', "unknown") sources.add(source) entities = doc.metadata['entities'] if 'entities' in doc.metadata.keys() else entities global_communities = doc.metadata[ "communitydetails"] if 'communitydetails' in doc.metadata.keys() else global_communities formatted_doc = ( "Document start\n" f"This Document belongs to the source {source}\n" f"Content: {doc.page_content}\n" "Document end\n" ) formatted_docs.append(formatted_doc) except Exception as e: logging.error(f"Error formatting document: {e}") return "\n\n".join(formatted_docs), sources, entities, global_communities def process_documents(docs, question, messages, llm, model, chat_mode_settings): start_time = time.time() try: formatted_docs, sources, entitydetails, communities = format_documents(docs, model) rag_chain = get_rag_chain(llm=llm) ai_response = rag_chain.invoke({ "messages": messages[:-1], "context": formatted_docs, "input": question }) result = {'sources': list(), 'nodedetails': dict(), 'entities': dict()} node_details = {"chunkdetails": list(), "entitydetails": list(), "communitydetails": list()} entities = {'entityids': list(), "relationshipids": list()} if chat_mode_settings["mode"] == CHAT_ENTITY_VECTOR_MODE: node_details["entitydetails"] = entitydetails elif chat_mode_settings["mode"] == CHAT_GLOBAL_VECTOR_FULLTEXT_MODE: node_details["communitydetails"] = communities else: sources_and_chunks = get_sources_and_chunks(sources, docs) result['sources'] = sources_and_chunks['sources'] node_details["chunkdetails"] = sources_and_chunks["chunkdetails"] entities.update(entitydetails) result["nodedetails"] = node_details result["entities"] = entities content = ai_response.content total_tokens = 
get_total_tokens(ai_response, llm) predict_time = time.time() - start_time logging.info(f"Final response predicted in {predict_time:.2f} seconds") except Exception as e: logging.error(f"Error processing documents: {e}") raise return content, result, total_tokens, formatted_docs def retrieve_documents(doc_retriever, messages): start_time = time.time() try: handler = CustomCallback() docs = doc_retriever.invoke({"messages": messages}, {"callbacks": [handler]}) transformed_question = handler.transformed_question if transformed_question: logging.info(f"Transformed question : {transformed_question}") doc_retrieval_time = time.time() - start_time logging.info(f"Documents retrieved in {doc_retrieval_time:.2f} seconds") except Exception as e: error_message = f"Error retrieving documents: {str(e)}" logging.error(error_message) raise RuntimeError(error_message) return docs, transformed_question def create_document_retriever_chain(llm, retriever): try: logging.info("Starting to create document retriever chain") query_transform_prompt = ChatPromptTemplate.from_messages( [ ("system", QUESTION_TRANSFORM_TEMPLATE), MessagesPlaceholder(variable_name="messages") ] ) output_parser = StrOutputParser() splitter = TokenTextSplitter(chunk_size=CHAT_DOC_SPLIT_SIZE, chunk_overlap=0) embeddings_filter = EmbeddingsFilter( embeddings=EMBEDDING_FUNCTION, similarity_threshold=CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD ) pipeline_compressor = DocumentCompressorPipeline( transformers=[splitter, embeddings_filter] ) compression_retriever = ContextualCompressionRetriever( base_compressor=pipeline_compressor, base_retriever=retriever ) query_transforming_retriever_chain = RunnableBranch( ( lambda x: len(x.get("messages", [])) == 1, (lambda x: x["messages"][-1].content) | compression_retriever, ), query_transform_prompt | llm | output_parser | compression_retriever, ).with_config(run_name="chat_retriever_chain") logging.info("Successfully created document retriever chain") return 
query_transforming_retriever_chain except Exception as e: logging.error(f"Error creating document retriever chain: {e}", exc_info=True) raise def initialize_neo4j_vector(graph, chat_mode_settings): try: retrieval_query = chat_mode_settings.get("retrieval_query") index_name = chat_mode_settings.get("index_name") keyword_index = chat_mode_settings.get("keyword_index", "") node_label = chat_mode_settings.get("node_label") embedding_node_property = chat_mode_settings.get("embedding_node_property") text_node_properties = chat_mode_settings.get("text_node_properties") if not retrieval_query or not index_name: raise ValueError("Required settings 'retrieval_query' or 'index_name' are missing.") if keyword_index: neo_db = Neo4jVector.from_existing_graph( embedding=EMBEDDING_FUNCTION, index_name=index_name, retrieval_query=retrieval_query, graph=graph, search_type="hybrid", node_label=node_label, embedding_node_property=embedding_node_property, text_node_properties=text_node_properties, keyword_index_name=keyword_index ) logging.info( f"Successfully retrieved Neo4jVector Fulltext index '{index_name}' and keyword index '{keyword_index}'") else: neo_db = Neo4jVector.from_existing_graph( embedding=EMBEDDING_FUNCTION, index_name=index_name, retrieval_query=retrieval_query, graph=graph, node_label=node_label, embedding_node_property=embedding_node_property, text_node_properties=text_node_properties ) logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'") except Exception as e: index_name = chat_mode_settings.get("index_name") logging.error(f"Error retrieving Neo4jVector index {index_name} : {e}") raise return neo_db def create_retriever(neo_db, document_names, chat_mode_settings, search_k, score_threshold): if document_names and chat_mode_settings["document_filter"]: retriever = neo_db.as_retriever( search_type="similarity_score_threshold", search_kwargs={ 'k': search_k, 'score_threshold': score_threshold, 'filter': {'fileName': {'$in': document_names}} } ) 
logging.info( f"Successfully created retriever with search_k={search_k}, score_threshold={score_threshold} for documents {document_names}") else: retriever = neo_db.as_retriever( search_type="similarity_score_threshold", search_kwargs={'k': search_k, 'score_threshold': score_threshold} ) logging.info(f"Successfully created retriever with search_k={search_k}, score_threshold={score_threshold}") return retriever def get_neo4j_retriever(graph, document_names, chat_mode_settings, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD): try: neo_db = initialize_neo4j_vector(graph, chat_mode_settings) # document_names= list(map(str.strip, json.loads(document_names))) search_k = chat_mode_settings["top_k"] retriever = create_retriever(neo_db, document_names, chat_mode_settings, search_k, score_threshold) return retriever except Exception as e: index_name = chat_mode_settings.get("index_name") logging.error(f"Error retrieving Neo4jVector index {index_name} or creating retriever: {e}") raise Exception( f"An error occurred while retrieving the Neo4jVector index or creating the retriever. 
Please drop and create a new vector index '{index_name}': {e}") from e def setup_chat(model, graph, document_names, chat_mode_settings): start_time = time.time() try: if model == "diffbot": model = os.getenv('DEFAULT_DIFFBOT_CHAT_MODEL') llm, model_name = get_llm(model=model) logging.info(f"Model called in chat: {model} (version: {model_name})") retriever = get_neo4j_retriever(graph=graph, chat_mode_settings=chat_mode_settings, document_names=document_names) doc_retriever = create_document_retriever_chain(llm, retriever) chat_setup_time = time.time() - start_time logging.info(f"Chat setup completed in {chat_setup_time:.2f} seconds") except Exception as e: logging.error(f"Error during chat setup: {e}", exc_info=True) raise return llm, doc_retriever, model_name def process_chat_response(messages, history, question, model, graph, document_names, chat_mode_settings): try: llm, doc_retriever, model_version = setup_chat(model, graph, document_names, chat_mode_settings) docs, transformed_question = retrieve_documents(doc_retriever, messages) if docs: content, result, total_tokens, formatted_docs = process_documents(docs, question, messages, llm, model, chat_mode_settings) else: content = "I couldn't find any relevant documents to answer your question." 
result = {"sources": list(), "nodedetails": list(), "entities": list()} total_tokens = 0 formatted_docs = "" ai_response = AIMessage(content=content) messages.append(ai_response) # summarization_thread = threading.Thread(target=summarize_and_log, args=(history, messages, llm)) # summarization_thread.start() logging.info("Summarization thread started.") # summarize_and_log(history, messages, llm) metric_details = {"question": question, "contexts": formatted_docs, "answer": content} return { "session_id": "", "message": content, "info": { # "metrics" : metrics, "sources": result["sources"], "model": model_version, "nodedetails": result["nodedetails"], "total_tokens": total_tokens, "response_time": 0, "mode": chat_mode_settings["mode"], "entities": result["entities"], "metric_details": metric_details, }, "user": "chatbot" } except Exception as e: logging.exception(f"Error processing chat response at {datetime.now()}: {str(e)}") return { "session_id": "", "message": "Something went wrong", "info": { "metrics": [], "sources": [], "nodedetails": [], "total_tokens": 0, "response_time": 0, "error": f"{type(e).__name__}: {str(e)}", "mode": chat_mode_settings["mode"], "entities": [], "metric_details": {}, }, "user": "chatbot" } def summarize_and_log(history, stored_messages, llm): logging.info("Starting summarization in a separate thread.") if not stored_messages: logging.info("No messages to summarize.") return False try: start_time = time.time() summarization_prompt = ChatPromptTemplate.from_messages( [ MessagesPlaceholder(variable_name="chat_history"), ( "human", "Summarize the above chat messages into a concise message, focusing on key points and relevant details that could be useful for future conversations. Exclude all introductions and extraneous information." 
), ] ) summarization_chain = summarization_prompt | llm summary_message = summarization_chain.invoke({"chat_history": stored_messages}) with threading.Lock(): history.clear() history.add_user_message("Our current conversation summary till now") history.add_message(summary_message) history_summarized_time = time.time() - start_time logging.info(f"Chat History summarized in {history_summarized_time:.2f} seconds") return True except Exception as e: logging.error(f"An error occurred while summarizing messages: {e}", exc_info=True) return False def create_graph_chain(model, graph): try: logging.info(f"Graph QA Chain using LLM model: {model}") cypher_llm, model_name = get_llm(model) qa_llm, model_name = get_llm(model) graph_chain = GraphCypherQAChain.from_llm( cypher_llm=cypher_llm, qa_llm=qa_llm, validate_cypher=True, graph=graph, # verbose=True, allow_dangerous_requests=True, return_intermediate_steps=True, top_k=3 ) logging.info("GraphCypherQAChain instance created successfully.") return graph_chain, qa_llm, model_name except Exception as e: logging.error(f"An error occurred while creating the GraphCypherQAChain instance. 
: {e}") def get_graph_response(graph_chain, question): try: cypher_res = graph_chain.invoke({"query": question}) response = cypher_res.get("result") cypher_query = "" context = [] for step in cypher_res.get("intermediate_steps", []): if "query" in step: cypher_string = step["query"] cypher_query = cypher_string.replace("cypher\n", "").replace("\n", " ").strip() elif "context" in step: context = step["context"] return { "response": response, "cypher_query": cypher_query, "context": context } except Exception as e: logging.error(f"An error occurred while getting the graph response : {e}") def process_graph_response(model, graph, question, messages, history): try: graph_chain, qa_llm, model_version = create_graph_chain(model, graph) graph_response = get_graph_response(graph_chain, question) ai_response_content = graph_response.get("response", "Something went wrong") ai_response = AIMessage(content=ai_response_content) messages.append(ai_response) # summarize_and_log(history, messages, qa_llm) # summarization_thread = threading.Thread(target=summarize_and_log, args=(history, messages, qa_llm)) # summarization_thread.start() logging.info("Summarization thread started.") metric_details = {"question": question, "contexts": graph_response.get("context", ""), "answer": ai_response_content} result = { "session_id": "", "message": ai_response_content, "info": { "model": model_version, "cypher_query": graph_response.get("cypher_query", ""), "context": graph_response.get("context", ""), "mode": "graph", "response_time": 0, "metric_details": metric_details, }, "user": "chatbot" } return result except Exception as e: logging.exception(f"Error processing graph response at {datetime.now()}: {str(e)}") return { "session_id": "", "message": "Something went wrong", "info": { "model": model_version, "cypher_query": "", "context": "", "mode": "graph", "response_time": 0, "error": f"{type(e).__name__}: {str(e)}" }, "user": "chatbot" } def create_neo4j_chat_message_history(graph, 
session_id, write_access=True): """ Creates and returns a Neo4jChatMessageHistory instance. """ try: if write_access: history = Neo4jChatMessageHistory( graph=graph, session_id=session_id ) return history history = get_history_by_session_id(session_id) return history except Exception as e: logging.error(f"Error creating Neo4jChatMessageHistory: {e}") raise def get_chat_mode_settings(mode, settings_map=CHAT_MODE_CONFIG_MAP): default_settings = settings_map[CHAT_DEFAULT_MODE] try: chat_mode_settings = settings_map.get(mode, default_settings) chat_mode_settings["mode"] = mode logging.info(f"Chat mode settings: {chat_mode_settings}") except Exception as e: logging.error(f"Unexpected error: {e}", exc_info=True) raise return chat_mode_settings def QA_RAG(question, session_id, mode, write_access=True): logging.info(f"Chat Mode: {mode}") graph = setup_neo4j_connection() model = MODEL document_names = '["AM400_User_guide_1_12.md", "AM400_User_guide_13_24.md", "AM400_User_guide_25_36.md", ' \ '"AM400_User_guide_37_56.md", "AM400_User_guide_57_63.md", "AM400_User_guide_64_90.md",' \ '"AM400_User_guide_90_104.md", "AM400_User_guide_105_114.md", "AM400_User_guide_115_140.md",' \ '"AM400_User_guide_141_153.md", "AM400_User_guide_154_163.md", "AM400_User_guide_164_176.md",' \ '"AM400_User_guide_177_178.md", "AM400_User_guide_179_211.md", "AM400_User_guide_212_228.md",' \ '"AM400_User_guide_229_240.md"]' history = create_neo4j_chat_message_history(graph, session_id, write_access) messages = history.messages user_question = HumanMessage(content=question) messages.append(user_question) if mode == CHAT_GRAPH_MODE: result = process_graph_response(model, graph, question, messages, history) else: chat_mode_settings = get_chat_mode_settings(mode=mode) document_names = list(map(str.strip, json.loads(document_names))) result = process_chat_response(messages, history, question, model, graph, document_names, chat_mode_settings) result["session_id"] = session_id # print('result', result) 
return result['message'] print("Starting script...") def validate_video(video_path): """ Validate video file and check if frames can be extracted. Returns tuple of (is_valid, message, num_frames) """ try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): return False, "Failed to open video file", 0 # Get video properties fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) print(f"Video properties: {width}x{height} @ {fps}fps, {total_frames} frames") if total_frames < 1: return False, "Video has no frames", 0 if fps <= 0: return False, "Invalid video FPS", 0 # Try to read first frame ret, frame = cap.read() if not ret or frame is None: return False, "Failed to read first frame", 0 # Try to read last frame cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1) ret, frame = cap.read() if not ret or frame is None: return False, "Failed to read last frame", 0 cap.release() return True, "Video validated successfully", total_frames except Exception as e: return False, f"Error validating video: {str(e)}", 0 def make_single_video_dataloader(video_path, config): """Create a dataloader for a single video using CSV file""" print(f"Creating data loader for video: {video_path}") try: # Validate video first cap = cv2.VideoCapture(video_path) if not cap.isOpened(): print("Failed to open video file") return None # Get video properties total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() if total_frames < 1: print("Video has no frames") return None # Create temporary directory for CSV # with tempfile.TemporaryDirectory() as temp_dir: # Create CSV file listing the video # csv_path = os.path.join(temp_dir, 'video_list.csv') csv_path = 'video_list.csv' with open(csv_path, 'w') as f: # Write single line with path and label separated by space f.write(f"{video_path} 0\n") print(f"Created CSV at: {csv_path}") data = pd.read_csv(csv_path, 
header=None, delimiter=" ") samples = list(data.values[:, 0]) print('samples', samples) labels = list(data.values[:, 1]) print('labels', labels) transform = make_transforms( training=False, num_views_per_clip=config['num_views_per_segment'], random_horizontal_flip=False, random_resize_aspect_ratio=(0.75, 4 / 3), random_resize_scale=(0.08, 1.0), reprob=0.25, auto_augment=True, motion_shift=False, crop_size=config['resolution'], ) print(f"Creating data loader with CSV: {csv_path}") try: data_loader, info = init_data( data='VideoDataset', root_path=[csv_path], transform=transform, batch_size=1, world_size=1, rank=0, clip_len=config['frames_per_clip'], frame_sample_rate=config['frame_step'], num_clips=config['num_segments'], allow_clip_overlap=True, num_workers=1, copy_data=False, drop_last=False ) print(f"Data loader info: {info}") except Exception as e: print("\n=== Data Loader Creation Error ===") print(f"Error type: {type(e).__name__}") print(f"Error message: {str(e)}") print("\nFull traceback:") traceback.print_exc(file=sys.stdout) return None return data_loader except Exception as e: print(f"Error creating data loader: {str(e)}") return None class VideoClassifier: def __init__(self, config): print("Initializing VideoClassifier...") self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {self.device}") self.config = config self.encoder, self.classifier = self.load_models() self.transform = make_transforms( training=False, num_views_per_clip=config['num_views_per_segment'], crop_size=config['resolution'] ) def load_checkpoint(self, device, r_path, classifier, opt, scaler): """ Load checkpoint for classifier, optimizer and scaler. 
Args: device: torch device to load the model onto r_path: path to checkpoint file classifier: classifier model to load weights into opt: optimizer to load state into scaler: gradient scaler for mixed precision training (can be None) Returns: tuple: (classifier, optimizer, scaler, epoch) """ try: print(f'Loading checkpoint from: {r_path}') checkpoint = torch.load(r_path, map_location=torch.device('cpu')) # Get epoch from checkpoint epoch = checkpoint.get('epoch', 0) print(f'Checkpoint from epoch: {epoch}') # Clean and load classifier state dict if 'classifier' in checkpoint: state_dict = checkpoint['classifier'] # Remove 'module.' prefix if present new_state_dict = {} for k, v in state_dict.items(): if k.startswith('module.'): new_state_dict[k[7:]] = v else: new_state_dict[k] = v try: msg = classifier.load_state_dict(new_state_dict) print(f'Loaded classifier state with message: {msg}') except Exception as e: print(f'Error loading classifier state dict: {str(e)}') raise else: print('No classifier state found in checkpoint') raise KeyError('No classifier state in checkpoint') # Load optimizer state if present if 'opt' in checkpoint and opt is not None: try: opt.load_state_dict(checkpoint['opt']) print('Loaded optimizer state') except Exception as e: print(f'Error loading optimizer state: {str(e)}') raise else: print('No optimizer state found in checkpoint') # Load scaler state if present and scaler is provided if scaler is not None and 'scaler' in checkpoint: try: scaler.load_state_dict(checkpoint['scaler']) print('Loaded scaler state') except Exception as e: print(f'Error loading scaler state: {str(e)}') raise elif scaler is not None: print('No scaler state found in checkpoint') # Clean up memory del checkpoint torch.cuda.empty_cache() return classifier, opt, scaler, epoch except FileNotFoundError: print(f'Checkpoint file not found: {r_path}') return classifier, opt, scaler, 0 except Exception as e: print(f'Error loading checkpoint: {str(e)}') print(f'Traceback: 
{traceback.format_exc()}') return classifier, opt, scaler, 0 def load_models(self): print("Loading models...") # Initialize encoder encoder = init_model( crop_size=self.config['resolution'], device=self.device, pretrained=self.config['pretrained_path'], model_name=self.config['model_name'], patch_size=self.config['patch_size'], tubelet_size=self.config['tubelet_size'], frames_per_clip=self.config['frames_per_clip'], uniform_power=self.config['uniform_power'], checkpoint_key=self.config['checkpoint_key'], use_SiLU=self.config['use_silu'], tight_SiLU=self.config['tight_silu'], use_sdpa=self.config['use_sdpa']) if self.config['frames_per_clip'] == 1: # Process each frame independently and aggregate encoder = FrameAggregation(encoder).to(self.device) else: # Process each video clip independently and aggregate encoder = ClipAggregation( encoder, tubelet_size=self.config['tubelet_size'], attend_across_segments=self.config['attend_across_segments'] ).to(self.device) print("Loading pretrained weights...") # Load pretrained weights checkpoint = torch.load(self.config['pretrained_path'], map_location=self.device) encoder.load_state_dict(checkpoint[self.config['checkpoint_key']], strict=False) encoder.eval() print("Initializing classifier...") # Initialize and load classifier classifier = AttentiveClassifier( embed_dim=encoder.embed_dim, num_heads=encoder.num_heads, depth=1, num_classes=self.config['num_classes'] ).to(self.device) optimizer, scaler, scheduler, wd_scheduler = init_opt( classifier=classifier, wd=0.01, start_lr=0.001, ref_lr=0.001, final_lr=0.0, iterations_per_epoch=1, warmup=0., num_epochs=20, use_bfloat16=True) # classifier = DistributedDataParallel(classifier, static_graph=True) classifier, optimizer, scaler, epoch = self.load_checkpoint( device=self.device, r_path=self.config['classifier_path'], classifier=classifier, opt=optimizer, scaler=scaler ) # classifier_checkpoint = torch.load(self.config['classifier_path'], map_location=self.device) # 
classifier.load_state_dict(classifier_checkpoint['classifier']) classifier.eval() return encoder, classifier def predict(self, video_path): """ Predict the class of a video. """ print(f"Starting prediction for video: {video_path}") try: # Check if video file exists if not os.path.exists(video_path): print(f"Error: Video file not found: {video_path}") return "Error: Video file not found", 0.0 # Check video file size file_size = os.path.getsize(video_path) if file_size == 0: print("Error: Video file is empty") return "Error: Empty video file", 0.0 print(f"Video file size: {file_size} bytes") # Initialize data loader data_loader = make_single_video_dataloader(video_path, self.config) if data_loader is None: return "Error: Failed to create data loader", 0.0 # Process the video with torch.no_grad(): try: # Get first batch batch = next(iter(data_loader)) # Prepare clips clips = [ [dij.to(self.device, non_blocking=True) for dij in di] for di in batch[0] ] clip_indices = [d.to(self.device, non_blocking=True) for d in batch[2]] print("Running encoder forward pass...") encoder_outputs = self.encoder(clips, clip_indices) print("Running classifier forward pass...") if self.config['attend_across_segments']: outputs = [self.classifier(o) for o in encoder_outputs] probs = sum([F.softmax(o, dim=1) for o in outputs]) / len(outputs) else: outputs = [[self.classifier(ost) for ost in os] for os in encoder_outputs] probs = sum([sum([F.softmax(ost, dim=1) for ost in os]) for os in outputs]) / len(outputs) / len(outputs[0]) # Get prediction and confidence pred_class = probs.argmax(dim=1).item() confidence = probs.max(dim=1).values.item() class_name = self.config['class_names'][pred_class] print(f'Predicted class: {class_name} (index: {pred_class})') print(f'Confidence: {confidence:.4f}') return class_name, confidence except StopIteration: print("Error: No data in data loader") return "Error: Failed to load video frames", 0.0 except Exception as e: print(f"Error during prediction: 
{str(e)}") return f"Error during prediction: {str(e)}", 0.0 except Exception as e: print(f"Error in prediction pipeline: {str(e)}") return f"Error in prediction pipeline: {str(e)}", 0.0 def process_video(classifier, video_path): print("Processing video...") if video_path is None: return "No video uploaded", 0.0 mp4_path = None # Ensure mp4_path is always defined try: if isinstance(video_path, str): print(f"Processing video from path: {video_path}") # Check if the file is in .webm format if video_path.endswith('.webm'): print("Converting .webm to .mp4 using ffmpeg...") with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file: mp4_path = temp_file.name # Use ffmpeg to convert the file command = ['ffmpeg', '-y', '-i', video_path, '-c:v', 'libx264', '-preset', 'fast', '-crf', '23', '-c:a', 'aac', mp4_path] result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) if result.returncode != 0: print(f"FFmpeg error: {result.stderr.decode()}") return "Error during video conversion", 0.0 print(f"Converted video saved to: {mp4_path}") video_path = mp4_path # Update video_path to use the converted file # Process the video file try: return classifier.predict(video_path) finally: # Clean up the temporary converted file if it was created if mp4_path and os.path.exists(mp4_path): os.unlink(mp4_path) else: print(f"Error: Unexpected video input type: {type(video_path)}") return "Error: Invalid video format", 0.0 except Exception as e: print(f"Error processing video: {str(e)}") return f"Error: {str(e)}", 0.0 def process_highlights_and_images(qa_reuslt_pure): if not qa_reuslt_pure: return "", [], {} response = qa_reuslt_pure # Process summary and get cleaned lines cleaned_lines = [] print('qa_reuslt_pure', qa_reuslt_pure) for line in qa_reuslt_pure.strip().split('\n'): cleaned_line = line.replace("**", "") cleaned_lines.append(cleaned_line) # Setup Neo4j connection graph = setup_neo4j_connection() try: # Process each line for images captions = 
def process_highlights_and_images(qa_reuslt_pure):
    """Match a QA answer against figure embeddings in Neo4j.

    Args:
        qa_reuslt_pure: Plain-text answer from the RAG pipeline (sic — the
            parameter name's typo is kept for interface compatibility).

    Returns:
        Tuple ``(highlighted_response, images, image_info)`` where
        ``highlighted_response`` is HTML with matched phrases wrapped in
        ``<mark>`` tags, ``images`` is a list of image references and
        ``image_info`` a list of caption/similarity dicts.

    Fix: the original wrote ``+= f"{phrase}"`` in the high-similarity
    branch, which is byte-identical to the else branch, so nothing was
    ever highlighted. Matched phrases are now wrapped in ``<mark>`` —
    the result is rendered by a ``gr.HTML`` component downstream.
    """
    if not qa_reuslt_pure:
        return "", [], {}

    response = qa_reuslt_pure

    # Strip markdown bold markers so the embedding sees clean text.
    cleaned_lines = []
    print('qa_reuslt_pure', qa_reuslt_pure)
    for line in qa_reuslt_pure.strip().split('\n'):
        cleaned_lines.append(line.replace("**", ""))

    graph = setup_neo4j_connection()
    try:
        captions = []
        images = []
        similarities = []

        # For each answer line, fetch the single most-similar Figure node
        # above the similarity threshold.
        for line in cleaned_lines:
            line_embedding = EMBEDDING_FUNCTION.embed_query(line)
            query = """
            WITH $line_embedding AS line_embedding
            MATCH (f:Figure)
            WHERE vector.similarity.cosine(f.embedding, line_embedding) > $similarity_threshold
            RETURN f, vector.similarity.cosine(f.embedding, line_embedding) AS similarity
            ORDER BY similarity DESC
            LIMIT 1
            """
            result = graph.query(query, {
                "line_embedding": line_embedding,
                "similarity_threshold": 0.76
            })
            for record in result:
                node = record["f"]
                similarity = record["similarity"]
                if 'img_ref' in node:
                    images.append(node['img_ref'])
                    captions.append(node['description'])
                    similarities.append(similarity)

        highlighted_response = response
        if captions:
            # Compare each sentence of the answer against every caption and
            # highlight sentences that closely match a retrieved figure.
            response_phrases = response.split('. ')
            response_embeddings = [EMBEDDING_FUNCTION.embed_query(phrase) for phrase in response_phrases]
            caption_embeddings = [EMBEDDING_FUNCTION.embed_query(caption) for caption in captions]
            similarity_matrix = cosine_similarity(response_embeddings, caption_embeddings)

            highlighted_response = ""
            for i, phrase in enumerate(response_phrases):
                max_similarity = np.max(similarity_matrix[i])
                is_last_phrase = (i == len(response_phrases) - 1)
                if max_similarity >= 0.75:
                    # Bug fix: previously identical to the else branch.
                    highlighted_response += f"<mark>{phrase}</mark>"
                else:
                    highlighted_response += phrase
                if not is_last_phrase:
                    highlighted_response += ". "

        # Caption + similarity metadata for the gallery panel.
        image_info = [
            {"caption": caption, "similarity": f"{sim:.3f}"}
            for caption, sim in zip(captions, similarities)
        ]

        return highlighted_response, images, image_info
    finally:
        if hasattr(graph, 'close'):
            graph.close()
class AudioAgent:
    """Voice-driven controller for the video-QA pipeline.

    Transcribes spoken instructions with Whisper, classifies the user's
    intent with a chat model, drives the sequential stages
    (analyze video -> ask question -> answer -> highlights) and speaks the
    result back via OpenAI TTS. Stage progress is tracked in
    ``self.conversation_state``.
    """

    def __init__(self, api_key: str):
        self.client = OpenAI(api_key=api_key)
        self.temp_dir = tempfile.mkdtemp()  # Create a temporary directory for files
        print('temp dir', self.temp_dir)
        # Pipeline state shared across requests; 'current_stage' gates which
        # actions are allowed next.
        self.conversation_state = {
            'current_stage': None,
            'video_prediction': None,
            'video_confidence': None,
            'generated_question': None,
            'pure_response': None,
            'highlighted_response': None,
            'gallery': None,
            'image_infos': None,
            'webcam_active': False
        }

    def _get_temp_path(self, filename: str) -> str:
        """Generate a temporary file path inside this agent's temp directory."""
        temp_path = os.path.join(self.temp_dir, filename)
        print('temp_path', temp_path)
        return temp_path

    def transcribe_audio(self, audio_path: str) -> str:
        """Transcribe audio using Whisper.

        Returns the transcript text, or "" on any failure (invalid path,
        API error) — callers treat "" as "no usable input".
        """
        try:
            if not os.path.isfile(audio_path):
                print(f"Invalid audio path: {audio_path}")
                return ""
            with open(audio_path, "rb") as audio_file:
                transcript = self.client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file
                )
                print('transcript text', transcript.text)
                return transcript.text
        except Exception as e:
            print(f"Error in transcription: {e}")
            return ""

    def text_to_speech(self, text: str, filename: str) -> str:
        """Convert text to speech using OpenAI's TTS.

        Writes an mp3 into the agent's temp dir and returns its path,
        or "" on failure.
        """
        try:
            output_path = self._get_temp_path(filename)
            print('output_path', output_path)
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="alloy",
                input=text
            )
            # NOTE(review): stream_to_file is deprecated in newer openai SDKs
            # in favor of the with_streaming_response API — confirm SDK version.
            response.stream_to_file(output_path)
            return output_path
        except Exception as e:
            print(f"Error in TTS conversion: {e}")
            return ""

    def determine_action(self, instruction: str) -> str:
        """Classify the user's intent into one of the known action names.

        Returns one of: analyze_video, ask_question, answer_questions,
        generate_highlights, unclear. Falls back to "unclear" on API error.
        """
        messages = [
            {
                "role": "system",
                "content": f"""Determine what action the user wants to take based on the current conversation state:
Current state: {self.conversation_state['current_stage']}
Video prediction: {self.conversation_state['video_prediction']}

Possible actions:
- analyze_video: If they want to analyze a video
- ask_question: If they are asking their own question
- answer_questions: If they want to answer questions
- generate_highlights: If they want to process answers for highlights
- unclear: If the request is unclear

Return only one of these action names."""
            },
            {"role": "user", "content": str(instruction)}
        ]
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=messages
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error determining action: {e}")
            return "unclear"

    def extract_question(self, instruction: str) -> str:
        """Extract the actual question from user's voice input.

        Falls back to returning the raw instruction on API error.
        """
        messages = [
            {
                "role": "system",
                "content": "Extract the actual question from the user's input. Return only the question itself."
            },
            {"role": "user", "content": str(instruction)}
        ]
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=messages
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error extracting question: {e}")
            return instruction

    def process_request(self, instruction: str, **kwargs) -> Dict[str, Any]:
        """Process user instruction based on conversation state.

        Dispatches on the intent from determine_action(); each stage requires
        the previous one to have completed (enforced via conversation_state).
        Returns a results dict (with 'audio_path' for the spoken reply) or a
        dict with an 'error' key.
        """
        try:
            action = self.determine_action(instruction)
            results = {}
            final_audio_path = None

            # if action == "take_video":
            #     self.conversation_state['webcam_active'] = True
            #     return {
            #         "results": {"message": "Activating webcam..."},
            #         "action": "take_video",
            #         "webcam_active": True
            #     }

            if action == "analyze_video":
                video_input = kwargs.get('webcam_input')
                print(f"Video input received: {video_input}, type: {type(video_input)}")
                video_path = None
                # Handle different video input formats
                if isinstance(video_input, tuple):
                    # Webcam recording format
                    video_path = video_input[0] if video_input else None
                elif isinstance(video_input, str):
                    # Direct file path
                    video_path = video_input
                elif isinstance(video_input, dict) and 'name' in video_input:
                    # Uploaded file format
                    video_path = video_input['name']
                print(f"Processed video path: {video_path}")
                # NOTE(review): if video_path is missing or classifier is None,
                # this branch silently returns an empty results dict — confirm
                # whether an explicit error message is wanted here.
                if video_path and os.path.isfile(video_path):
                    classifier = kwargs.get('classifier')
                    if classifier:
                        video_result = process_video(classifier, video_path)
                        self.conversation_state['video_prediction'] = video_result[0]
                        self.conversation_state['video_confidence'] = video_result[1]
                        self.conversation_state['current_stage'] = 'video_analyzed'
                        results.update({
                            'video_prediction': video_result[0],
                            'video_confidence': video_result[1]
                        })
                        final_audio_path = self.text_to_speech(
                            f"Video analysis complete. Detected action: {video_result[0]} with confidence {video_result[1]:.2f}. Would you like to ask questions about this activity?",
                            "response_video.mp3"
                        )

            elif action == "ask_question":
                if self.conversation_state['video_prediction']:
                    # Extract the actual question from user's voice input
                    custom_question = self.extract_question(instruction)
                    self.conversation_state['generated_question'] = custom_question
                    self.conversation_state['current_stage'] = 'questions_generated'
                    results['generated_question'] = custom_question
                    final_audio_path = self.text_to_speech(
                        f"I understand your question: '{custom_question}'. Would you like me to answer it?",
                        "response_custom_question.mp3"
                    )
                else:
                    return {"error": "Please analyze a video first"}

            elif action == "answer_questions":
                if self.conversation_state['generated_question']:
                    # QA_RAG is presumably defined elsewhere in this file
                    # (knowledge-graph RAG entry point) — not visible here.
                    pure_response = QA_RAG(
                        self.conversation_state['generated_question'],
                        kwargs.get('session_id', 1),
                        kwargs.get('mode', 'CHAT_VECTOR_GRAPH_MODE')
                    )
                    self.conversation_state['pure_response'] = pure_response
                    self.conversation_state['current_stage'] = 'questions_answered'
                    results.update({
                        'pure_response': pure_response
                    })
                    final_audio_path = self.text_to_speech(
                        f"Sure! I've answered the questions. {pure_response}. Would you like me to generate highlights and related images?",
                        "response_qa.mp3"
                    )
                else:
                    return {"error": "Please generate questions first"}

            elif action == "generate_highlights":
                if self.conversation_state['pure_response']:
                    highlighted_response, gallery, image_infos = process_highlights_and_images(
                        self.conversation_state['pure_response']
                    )
                    self.conversation_state.update({
                        'highlighted_response': highlighted_response,
                        'gallery': gallery,
                        'image_infos': image_infos,
                        'current_stage': 'highlights_generated'
                    })
                    results.update({
                        'highlighted_response': highlighted_response,
                        'gallery': gallery,
                        'image_infos': image_infos
                    })
                    final_audio_path = self.text_to_speech(
                        "I've generated highlights and retrieved related images. Is there anything else you'd like to know?",
                        "response_highlights.mp3"
                    )
                else:
                    return {"error": "Please answer questions first"}

            else:
                # Unknown intent: suggest the next valid step for the stage.
                suggestion = ""
                if not self.conversation_state['current_stage']:
                    suggestion = "You can start by asking me to analyze a video."
                elif self.conversation_state['current_stage'] == 'video_analyzed':
                    suggestion = "You can ask me to generate questions about the activity."
                elif self.conversation_state['current_stage'] == 'questions_generated':
                    suggestion = "You can ask me to answer the questions."
                elif self.conversation_state['current_stage'] == 'questions_answered':
                    suggestion = "You can ask me to generate highlights and images."
                return {
                    "error": f"I'm not sure what you want me to do. {suggestion}"
                }

            return {
                "results": results,
                "audio_path": final_audio_path,
                "action": action,
                **results
            }
        except Exception as e:
            print(f"Error in process_request: {e}")
            return {"error": f"Error processing request: {e}"}
def process_audio_input(
    agent,
    audio_path: str,
    webcam_input: str = None,
    classifier=None,
    mode: str = 'CHAT_VECTOR_GRAPH_MODE',
    session_id: int = 1
) -> Tuple[Dict[str, Any], str]:
    """Transcribe a voice request and run it through the agent pipeline.

    Returns a pair ``(result_dict, audio_response_path)``; the path is ""
    when transcription fails or the agent reports an error.
    """
    # Speech -> text; an empty transcript means nothing usable was said.
    spoken_text = agent.transcribe_audio(audio_path)
    if not spoken_text:
        return {"error": "Error in transcription"}, ""

    # Hand the transcript to the sequential agent pipeline.
    outcome = agent.process_request(
        spoken_text,
        webcam_input=webcam_input,
        classifier=classifier,
        mode=mode,
        session_id=session_id
    )

    # Errors get no audio reply; otherwise surface the TTS file path (if any).
    reply_audio = "" if "error" in outcome else outcome.get("audio_path", "")
    return outcome, reply_audio
def main():
    """Download model weights, build the classifier and launch the Gradio UI.

    Returns the constructed gr.Blocks demo (launched by the __main__ guard).
    """
    print("Starting main function...")
    # from huggingface_hub import snapshot_download
    # snapshot_download(repo_id="IELTS8/pretrained_model")
    import os
    import urllib.request

    # Define file paths and URLs for the pretrained encoder and the
    # attentive-probe classifier checkpoint.
    models = {
        "pretrained_model/vith16.pth.tar": "https://dl.fbaipublicfiles.com/jepa/vith16/vith16.pth.tar",
        "pretrained_model/v-jepa-224-classifier/jepa-latest.pth.tar": "https://huggingface.co/IELTS8/pretrained_model/resolve/main/v-jepa-224-classifier/jepa-latest.pth.tar"
    }

    # Function to ensure directory exists
    def ensure_directory_exists(file_path):
        directory = os.path.dirname(file_path)
        if not os.path.exists(directory):
            os.makedirs(directory)

    # Function to download file if it does not exist
    def download_file(file_path, url):
        if not os.path.exists(file_path):
            print(f"Downloading {file_path} from {url}...")
            ensure_directory_exists(file_path)
            urllib.request.urlretrieve(url, file_path)
            print(f"Downloaded: {file_path}")
        else:
            print(f"File already exists: {file_path}")

    # NOTE(review): update_video is never referenced below — dead code?
    def update_video(video_path):
        return video_path

    # Check and download models
    for file_path, url in models.items():
        download_file(file_path, url)

    # from huggingface_hub import hf_hub_download
    # hf_hub_download(repo_id="IELTS8/pretrained_model")

    # Configuration consumed by VideoClassifier (defined elsewhere in this
    # file): ViT-H/16 V-JEPA backbone, 11 manufacturing-operation classes.
    model_config = {
        'model_name': 'vit_huge',
        'resolution': 224,
        'patch_size': 16,
        'frames_per_clip': 16,
        'tubelet_size': 2,
        'uniform_power': True,
        'use_sdpa': True,
        'use_silu': False,
        'tight_silu': False,
        'num_classes': 11,
        'pretrained_path': 'pretrained_model/vith16.pth.tar',
        'classifier_path': 'pretrained_model/v-jepa-224-classifier/jepa-latest.pth.tar',
        'checkpoint_key': 'target_encoder',
        'num_segments': 2,
        'num_views_per_segment': 3,
        'frame_step': 3,
        'attend_across_segments': True,
        'class_names': [
            'clean the build plate',
            'clean the recoater',
            'insert the build plate screw',
            'measure the build plate thickness',
            'open the process chamber door',
            'operate valves of the large safe change filter',
            'place build plate inside the process chamber',
            'place recoater into the process chamber',
            'press reset button on the control panel',
            'press stop button on the control panel',
            'wipe laser window inside the process chamber'
        ]
    }

    world_size, rank = init_distributed()
    print(f'Initialized (rank/world-size) {rank}/{world_size}')
    print("Creating classifier...")
    classifier = VideoClassifier(model_config)

    # print("Setting up Gradio interface...")
    # with gr.Blocks(theme=gr.themes.Soft()) as demo:
    #     gr.Markdown("# 🏭 Sim2Real: Action Recognition (Synthetic Data) & Knowledge Guidance (Knowledge Graph)")

    # Create a custom theme
    # NOTE(review): the inline comments below claim 70%/50%/30% opacity but
    # all three values are 0.3 — confirm intended values.
    custom_theme = gr.themes.Soft().set(
        body_background_fill="rgba(255, 255, 255, 0.3)",  # Main background with 70% opacity
        block_background_fill="rgba(255, 255, 255, 0.3)",  # Component blocks with 50% opacity
        panel_background_fill="rgba(255, 255, 255, 0.3)",  # Panels with 30% opacity
        # Additional theme customizations for better visibility
        block_border_width="1px",
        block_border_color="rgba(128, 128, 128, 0.3)",
        block_radius="8px",
        # Text colors for better contrast against transparent background
        block_title_text_color="rgba(0, 0, 0, 0.8)",
        block_label_text_color="rgba(0, 0, 0, 0.7)",
        body_text_color="rgba(0, 0, 0, 0.8)"
    )

    print("Setting up Gradio interface...")
    sample_videos = ["example_1.mp4", "example_2.mp4", "example_3.mp4", "example_4.mp4", "example_5.mp4"]

    with gr.Blocks(theme=custom_theme, css="""
        .gradio-container {
            background: rgba(255, 255, 255, 0.7) !important;
        }
        .block {
            background: rgba(255, 255, 255, 0.5) !important;
            backdrop-filter: blur(10px);
        }
        .panel {
            background: rgba(255, 255, 255, 0.3) !important;
        }
    """) as demo:
        gr.Markdown("# 🏭 Sim2Real: Action Recognition (Synthetic Data) & Knowledge Guidance (Knowledge Graph)")

        # Add state for webcam control
        last_processed_time = gr.State(time.time())
        is_processing = gr.State(False)

        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 🎬 Operation Recognition")
                with gr.Row():
                    with gr.Column(scale=3):
                        video_input = gr.Video(
                            label="Upload Video",
                            sources=["upload", "webcam"],
                            height=400,
                            width=600
                        )
                    with gr.Column(scale=1):
                        # Force examples to render in the right column
                        with gr.Group():
                            gr.Examples(
                                examples=sample_videos,
                                inputs=video_input,
                                examples_per_page=5,  # Show all examples at once
                                label="Sample Videos"
                            )
                with gr.Row():
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Speak your request",
                        streaming=False
                    )
                with gr.Row():
                    text_output = gr.Textbox(label="Response")
                    audio_output = gr.Audio(
                        label="Audio Response",
                        autoplay=True
                    )
                status = gr.Textbox(label="Status", value="Ready")
                # Hidden components for storing intermediate results
                video_prediction = gr.Textbox(visible=False)
                video_confidence = gr.Number(visible=False)
                generated_question = gr.Textbox(visible=False)
            with gr.Column(scale=3):
                gr.Markdown("### 🔍 Knowledge Guidance")
                with gr.Row():
                    mode = gr.Dropdown(
                        choices=[
                            'CHAT_VECTOR_GRAPH_MODE',
                            'CHAT_GRAPH_MODE',
                            'CHAT_VECTOR_MODE',
                            'CHAT_FULLTEXT_MODE',
                            'CHAT_VECTOR_GRAPH_FULLTEXT_MODE'
                        ],
                        label="Select Mode",
                        value='CHAT_VECTOR_GRAPH_MODE',
                        visible=False,
                    )
                    session_id = gr.Number(
                        label="Session ID",
                        value=1,
                        visible=False
                    )
                pure_response = gr.Markdown(visible=False)
                highlighted_response = gr.HTML(label="Highlighted Response")
                with gr.Row():
                    gallery = gr.Gallery(label="Related Images", height=150)
                    image_infos = gr.JSON(label="Image Information", open=False)

        agent = AudioAgent(api_key=os.getenv("OPENAI_API_KEY"))

        def process_with_status(audio_path, last_time, is_proc, video_path, mode_value, session_id_value):
            """Debounced handler: run the audio request and fan results out to the UI."""
            # Debug logging
            print("=== Input Debug ===")
            print(f"Audio path: {audio_path}, type: {type(audio_path)}")
            print(f"Video path: {video_path}, type: {type(video_path)}")
            print(f"Mode value: {mode_value}")
            print(f"Session ID: {session_id_value}")

            # Video path handling with debug
            # NOTE(review): video_file is computed but never used below —
            # process_audio_input receives the raw video_path instead.
            video_file = None
            if isinstance(video_path, tuple) and len(video_path) > 0:
                video_file = video_path[0]
                print(f"Extracted video file from tuple: {video_file}")
            elif isinstance(video_path, str):
                video_file = video_path
                print(f"Using video path directly: {video_file}")
            else:
                print(f"Unexpected video_path type: {type(video_path)}, value: {video_path}")

            current_time = time.time()

            # Check if enough time has passed (3 seconds minimum)
            # NOTE(review): the comment says 3 seconds but the check uses 4.
            # NOTE(review): these dict returns key on last_processed_time /
            # is_processing, which are NOT in the event's outputs list below —
            # verify Gradio accepts this; also is_proc = True is never persisted
            # back to the is_processing State.
            if current_time - last_time < 4 or is_proc:
                return {
                    status: "Processing previous request...",
                    last_processed_time: last_time,
                    is_processing: is_proc
                }

            # Set processing flag
            is_proc = True

            # Process the audio
            if audio_path is None:
                return {
                    status: "No audio detected",
                    last_processed_time: current_time,
                    is_processing: False
                }

            result, audio_path = process_audio_input(
                agent=agent,
                audio_path=audio_path,
                webcam_input=video_path,
                classifier=classifier,
                mode=mode_value,
                session_id=session_id_value
            )

            print("=== Output Debug ===")
            print(f"Result: {result}")
            print(f"Audio response path: {audio_path}")

            # # Handle webcam activation
            # if result.get("action") == "take_video":
            #     return {
            #         video_input: gr.update(interactive=True),
            #         status: "Webcam activated. You can now record a video.",
            #         text_output: "Webcam activated"
            #     }

            if "error" in result:
                return result["error"], None, "", 0, [], "", "", [], {}, f"Error: {result['error']}"

            # status.update(value="Ready")
            # NOTE(review): the whole result dict is sent to text_output (a
            # Textbox) — it will be stringified; confirm that is intended.
            return (
                result,  # text_output
                audio_path,  # audio_output
                result.get("video_prediction", ""),  # video_prediction
                result.get("video_confidence", 0),  # video_confidence
                result.get("generated_question", []),  # generated_question
                result.get("pure_response", ""),  # pure_response
                result.get("highlighted_response", ""),  # highlighted_response
                result.get("gallery", []),  # gallery
                result.get("image_infos", {}),  # image_infos,
                "Ready"  # status
            )

        audio_input.change(
            fn=process_with_status,
            inputs=[
                audio_input,
                last_processed_time,
                is_processing,
                video_input,
                mode,
                session_id
            ],
            outputs=[
                text_output,
                audio_output,
                video_prediction,
                video_confidence,
                generated_question,
                pure_response,
                highlighted_response,
                gallery,
                image_infos,
                status
            ]
        )

    print("Launching interface...")
    return demo


if __name__ == "__main__":
    print("Script started")
    demo = main()
    demo.launch(share=True)