| | |
| |
|
| | import os |
| | import subprocess |
| | import sys |
| | import traceback |
| | import gradio as gr |
| | from openai import OpenAI |
| |
|
| | from typing import Dict |
| |
|
| | import cv2 |
| | import numpy as np |
| | import pandas as pd |
| | from sklearn.metrics.pairwise import cosine_similarity |
| |
|
| | from video_classification_frozen.eval import init_opt, init_model |
| | from src.datasets.data_manager import init_data |
| | from src.utils.distributed import init_distributed |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import tempfile |
| |
|
| | from src.models.attentive_pooler import AttentiveClassifier |
| | from video_classification_frozen.utils import make_transforms, FrameAggregation, ClipAggregation |
| |
|
| | import logging |
| | import threading |
| | from datetime import datetime |
| | from typing import Any, Tuple |
| |
|
| | import os |
| | from dotenv import load_dotenv |
| | import time |
| |
|
| | from langchain.chains.llm import LLMChain |
| | from langchain_community.chat_message_histories import ChatMessageHistory |
| | from langchain.retrievers import ContextualCompressionRetriever |
| | from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline |
| | from langchain_community.vectorstores import Neo4jVector |
| | from langchain_core.callbacks import BaseCallbackHandler |
| | from langchain_core.output_parsers import StrOutputParser |
| | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| | from langchain_core.runnables import RunnableBranch |
| | from langchain_openai import ChatOpenAI |
| | from langchain_text_splitters import TokenTextSplitter |
| |
|
| | from kg_sys.llm import get_llm |
| | from kg_sys.shared.constants import * |
| | from kg_sys.shared.common_fn import create_graph_database_connection, load_embedding_model |
| | from langchain_core.messages import HumanMessage, AIMessage |
| | from langchain_community.chat_message_histories import Neo4jChatMessageHistory |
| | from langchain.chains import GraphCypherQAChain |
| | import json |
| |
|
| | from kg_sys.shared.constants import CHAT_SYSTEM_TEMPLATE, CHAT_TOKEN_CUT_OFF |
| |
|
# Load variables from a local .env file before any os.getenv() reads below.
load_dotenv()
| |
|
| | |
# Runtime configuration.
# NOTE(security): the Neo4j URI/credentials were hard-coded in source. They
# are now read from the environment first, with the previous literals kept
# only as fallbacks so existing deployments keep working.
# TODO: rotate the exposed password and delete the fallback defaults.
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
NEO4J_URI = os.getenv("NEO4J_URI", "neo4j+s://02616835.databases.neo4j.io")
USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
PASSWORD = os.getenv("NEO4J_PASSWORD", "KBbDcrVF5pQklhEzfy8wXP9a5wIV_IfV7Qdr64SvTkQ")
MODEL = os.getenv("VITE_LLM_MODELS")
# Embedding function is shared by the vector store and the similarity filter.
EMBEDDING_FUNCTION, _ = load_embedding_model(EMBEDDING_MODEL)
| |
|
| |
|
| | |
def setup_neo4j_connection():
    """Open a Neo4j graph connection using the module-level credentials."""
    return create_graph_database_connection(
        uri=NEO4J_URI,
        userName=USERNAME,
        password=PASSWORD,
        database='neo4j',
    )
| |
|
| |
|
class SessionChatHistory:
    """Process-local registry of chat histories, keyed by session id."""

    # Shared across all callers in this process; never persisted.
    history_dict = {}

    @classmethod
    def get_chat_history(cls, session_id):
        """Retrieve or create chat message history for a given session ID."""
        history = cls.history_dict.get(session_id)
        if history is None:
            logging.info(f"Creating new ChatMessageHistory Local for session ID: {session_id}")
            history = ChatMessageHistory()
            cls.history_dict[session_id] = history
        else:
            logging.info(f"Retrieved existing ChatMessageHistory Local for session ID: {session_id}")
        return history
| |
|
| |
|
class CustomCallback(BaseCallbackHandler):
    """Callback that captures the LLM-rewritten question after generation."""

    def __init__(self):
        # Filled in by on_llm_end once the query-transform LLM has finished.
        self.transformed_question = None

    def on_llm_end(self, response, **kwargs: Any) -> None:
        logging.info("question transformed")
        first_generation = response.generations[0][0]
        self.transformed_question = first_generation.text.strip()
| |
|
| |
|
def get_history_by_session_id(session_id):
    """Return the process-local chat history for *session_id* (created on demand)."""
    try:
        history = SessionChatHistory.get_chat_history(session_id)
    except Exception as e:
        logging.error(f"Failed to get history for session ID '{session_id}': {e}")
        raise
    return history
| |
|
| |
|
def get_total_tokens(ai_response, llm):
    """Best-effort extraction of the total token count from an AI response.

    Returns 0 when the LLM type is unrecognized or the metadata is missing.
    """
    total_tokens = 0
    try:
        if isinstance(llm, ChatOpenAI):
            usage = ai_response.response_metadata.get('token_usage', {})
            total_tokens = usage.get('total_tokens', 0)
        else:
            logging.warning(f"Unrecognized language model: {type(llm)}. Returning 0 tokens.")
    except Exception as e:
        logging.error(f"Error retrieving total tokens: {e}")
        total_tokens = 0
    return total_tokens
| |
|
| |
|
def clear_chat_history(graph, session_id, local=False):
    """Wipe the chat history for *session_id*.

    Uses the Neo4j-backed history unless *local* is True, in which case the
    in-process history is cleared. Always returns a chatbot status dict.
    """
    try:
        if local:
            history = get_history_by_session_id(session_id)
        else:
            history = Neo4jChatMessageHistory(graph=graph, session_id=session_id)
        history.clear()
        message = "The chat history has been cleared."
    except Exception as e:
        logging.error(f"Error clearing chat history for session {session_id}: {e}")
        message = "Failed to clear chat history."
    return {
        "session_id": session_id,
        "message": message,
        "user": "chatbot",
    }
| |
|
| |
|
def get_sources_and_chunks(sources_used, docs):
    """Collect de-duplicated chunk details for docs whose source was used.

    Scores are rounded to 4 decimal places, and only the first occurrence of
    each (id, rounded-score) pair is kept. Returns a dict with the original
    *sources_used* list plus the collected 'chunkdetails'.
    """
    chunkdetails_list = []
    wanted_sources = set(sources_used)
    seen_pairs = set()

    for doc in docs:
        try:
            if doc.metadata.get("source") not in wanted_sources:
                continue
            for chunkdetail in doc.metadata.get("chunkdetails", []):
                chunk_id = chunkdetail.get("id")
                score = round(chunkdetail.get("score", 0), 4)
                pair = (chunk_id, score)
                if pair in seen_pairs:
                    continue
                seen_pairs.add(pair)
                chunkdetails_list.append({**chunkdetail, "score": score})
        except Exception as e:
            logging.error(f"Error processing document: {e}")

    return {
        'sources': sources_used,
        'chunkdetails': chunkdetails_list,
    }
| |
|
| |
|
def get_rag_chain(llm, system_template=CHAT_SYSTEM_TEMPLATE):
    """Build the answering chain: system prompt + chat history + question | llm."""
    try:
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                MessagesPlaceholder(variable_name="messages"),
                ("human", "User question: {input}"),
            ]
        )
        return prompt | llm
    except Exception as e:
        logging.error(f"Error creating RAG chain: {e}")
        raise
| |
|
| |
|
def format_documents(documents, model):
    """Rank docs by similarity, truncate to the model's cutoff, and render them.

    Returns (joined_text, sources_set, entities, global_communities). Entity
    and community metadata come from the last kept document that carried them.
    """
    prompt_token_cutoff = 4
    for model_names, value in CHAT_TOKEN_CUT_OFF.items():
        if model in model_names:
            prompt_token_cutoff = value
            break

    ranked = sorted(
        documents,
        key=lambda doc: doc.state.get("query_similarity_score", 0),
        reverse=True,
    )[:prompt_token_cutoff]

    formatted_docs = []
    sources = set()
    entities = {}
    global_communities = []

    for doc in ranked:
        try:
            source = doc.metadata.get('source', "unknown")
            sources.add(source)

            # Keep the most recent entity/community metadata seen, if any.
            if 'entities' in doc.metadata:
                entities = doc.metadata['entities']
            if 'communitydetails' in doc.metadata:
                global_communities = doc.metadata["communitydetails"]

            formatted_docs.append(
                "Document start\n"
                f"This Document belongs to the source {source}\n"
                f"Content: {doc.page_content}\n"
                "Document end\n"
            )
        except Exception as e:
            logging.error(f"Error formatting document: {e}")

    return "\n\n".join(formatted_docs), sources, entities, global_communities
| |
|
| |
|
def process_documents(docs, question, messages, llm, model, chat_mode_settings):
    """Format retrieved docs, run the RAG chain, and assemble the result payload.

    Args:
        docs: documents returned by the retriever.
        question: raw user question.
        messages: full chat history; the last entry is the current question,
            so only messages[:-1] are passed as history.
        llm: chat model used to generate the answer.
        model: model name, used to pick the prompt token cutoff.
        chat_mode_settings: dict with at least a "mode" key.

    Returns:
        Tuple (content, result, total_tokens, formatted_docs).

    Raises:
        Exception: re-raised after logging on any failure.
    """
    start_time = time.time()

    try:
        formatted_docs, sources, entitydetails, communities = format_documents(docs, model)

        rag_chain = get_rag_chain(llm=llm)

        # History excludes the trailing user question; it is supplied as "input".
        ai_response = rag_chain.invoke({
            "messages": messages[:-1],
            "context": formatted_docs,
            "input": question
        })

        result = {'sources': list(), 'nodedetails': dict(), 'entities': dict()}
        node_details = {"chunkdetails": list(), "entitydetails": list(), "communitydetails": list()}
        entities = {'entityids': list(), "relationshipids": list()}

        # Each chat mode surfaces a different slice of retrieval metadata.
        if chat_mode_settings["mode"] == CHAT_ENTITY_VECTOR_MODE:
            node_details["entitydetails"] = entitydetails

        elif chat_mode_settings["mode"] == CHAT_GLOBAL_VECTOR_FULLTEXT_MODE:
            node_details["communitydetails"] = communities
        else:
            # Default modes report per-chunk provenance plus entity ids.
            sources_and_chunks = get_sources_and_chunks(sources, docs)
            result['sources'] = sources_and_chunks['sources']
            node_details["chunkdetails"] = sources_and_chunks["chunkdetails"]
            entities.update(entitydetails)

        result["nodedetails"] = node_details
        result["entities"] = entities

        content = ai_response.content
        total_tokens = get_total_tokens(ai_response, llm)

        predict_time = time.time() - start_time
        logging.info(f"Final response predicted in {predict_time:.2f} seconds")

    except Exception as e:
        logging.error(f"Error processing documents: {e}")
        raise

    return content, result, total_tokens, formatted_docs
| |
|
| |
|
def retrieve_documents(doc_retriever, messages):
    """Run the retriever chain over *messages*.

    Returns (docs, transformed_question); the latter is the LLM-rewritten
    query captured by the callback, or None when no rewrite happened.

    Raises:
        RuntimeError: wrapping any retrieval failure, after logging it.
    """
    started = time.time()
    handler = CustomCallback()
    try:
        docs = doc_retriever.invoke({"messages": messages}, {"callbacks": [handler]})
    except Exception as e:
        error_message = f"Error retrieving documents: {str(e)}"
        logging.error(error_message)
        raise RuntimeError(error_message)

    transformed_question = handler.transformed_question
    if transformed_question:
        logging.info(f"Transformed question : {transformed_question}")
    doc_retrieval_time = time.time() - started
    logging.info(f"Documents retrieved in {doc_retrieval_time:.2f} seconds")
    return docs, transformed_question
| |
|
| |
|
def create_document_retriever_chain(llm, retriever):
    """Build the retrieval chain: optional query rewrite -> compressed retrieval.

    Args:
        llm: model used to rewrite multi-turn questions into standalone queries.
        retriever: base vector retriever to wrap.

    Returns:
        A runnable mapping {"messages": [...]} to a list of documents.

    Raises:
        Exception: re-raised after logging if construction fails.
    """
    try:
        logging.info("Starting to create document retriever chain")

        query_transform_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", QUESTION_TRANSFORM_TEMPLATE),
                MessagesPlaceholder(variable_name="messages")
            ]
        )

        output_parser = StrOutputParser()

        # Split retrieved docs into token-bounded chunks, then drop chunks whose
        # embedding similarity to the query falls below the threshold.
        splitter = TokenTextSplitter(chunk_size=CHAT_DOC_SPLIT_SIZE, chunk_overlap=0)
        embeddings_filter = EmbeddingsFilter(
            embeddings=EMBEDDING_FUNCTION,
            similarity_threshold=CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD
        )

        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[splitter, embeddings_filter]
        )

        compression_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=retriever
        )

        # First turn: retrieve on the raw question. Later turns: let the LLM
        # rewrite the conversation into a standalone query first.
        query_transforming_retriever_chain = RunnableBranch(
            (
                lambda x: len(x.get("messages", [])) == 1,
                (lambda x: x["messages"][-1].content) | compression_retriever,
            ),
            query_transform_prompt | llm | output_parser | compression_retriever,
        ).with_config(run_name="chat_retriever_chain")

        logging.info("Successfully created document retriever chain")
        return query_transforming_retriever_chain

    except Exception as e:
        logging.error(f"Error creating document retriever chain: {e}", exc_info=True)
        raise
| |
|
| |
|
def initialize_neo4j_vector(graph, chat_mode_settings):
    """Create a Neo4jVector over an existing index described by *chat_mode_settings*.

    Uses hybrid (vector + keyword) search when a keyword index is configured,
    otherwise plain vector search.

    Raises:
        ValueError: if 'retrieval_query' or 'index_name' is missing.
        Exception: anything raised by Neo4jVector.from_existing_graph.
    """
    index_name = chat_mode_settings.get("index_name")
    try:
        retrieval_query = chat_mode_settings.get("retrieval_query")
        keyword_index = chat_mode_settings.get("keyword_index", "")

        if not retrieval_query or not index_name:
            raise ValueError("Required settings 'retrieval_query' or 'index_name' are missing.")

        # Arguments shared by both branches — the original duplicated the whole
        # from_existing_graph call, which invited the two copies drifting apart.
        vector_kwargs = dict(
            embedding=EMBEDDING_FUNCTION,
            index_name=index_name,
            retrieval_query=retrieval_query,
            graph=graph,
            node_label=chat_mode_settings.get("node_label"),
            embedding_node_property=chat_mode_settings.get("embedding_node_property"),
            text_node_properties=chat_mode_settings.get("text_node_properties"),
        )

        if keyword_index:
            neo_db = Neo4jVector.from_existing_graph(
                search_type="hybrid",
                keyword_index_name=keyword_index,
                **vector_kwargs,
            )
            logging.info(
                f"Successfully retrieved Neo4jVector Fulltext index '{index_name}' and keyword index '{keyword_index}'")
        else:
            neo_db = Neo4jVector.from_existing_graph(**vector_kwargs)
            logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
    except Exception as e:
        logging.error(f"Error retrieving Neo4jVector index {index_name} : {e}")
        raise
    return neo_db
| |
|
| |
|
def create_retriever(neo_db, document_names, chat_mode_settings, search_k, score_threshold):
    """Build a similarity-score-threshold retriever, optionally filtered by file name."""
    search_kwargs = {'k': search_k, 'score_threshold': score_threshold}
    filtered = bool(document_names) and chat_mode_settings["document_filter"]
    if filtered:
        search_kwargs['filter'] = {'fileName': {'$in': document_names}}

    retriever = neo_db.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs=search_kwargs,
    )
    if filtered:
        logging.info(
            f"Successfully created retriever with search_k={search_k}, score_threshold={score_threshold} for documents {document_names}")
    else:
        logging.info(f"Successfully created retriever with search_k={search_k}, score_threshold={score_threshold}")
    return retriever
| |
|
| |
|
def get_neo4j_retriever(graph, document_names, chat_mode_settings, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD):
    """Initialize the Neo4j vector store and wrap it in a configured retriever."""
    try:
        neo_db = initialize_neo4j_vector(graph, chat_mode_settings)
        return create_retriever(
            neo_db,
            document_names,
            chat_mode_settings,
            chat_mode_settings["top_k"],
            score_threshold,
        )
    except Exception as e:
        index_name = chat_mode_settings.get("index_name")
        logging.error(f"Error retrieving Neo4jVector index {index_name} or creating retriever: {e}")
        raise Exception(
            f"An error occurred while retrieving the Neo4jVector index or creating the retriever. Please drop and create a new vector index '{index_name}': {e}") from e
| |
|
| |
|
def setup_chat(model, graph, document_names, chat_mode_settings):
    """Resolve the LLM and document retriever chain used to answer a chat turn.

    Returns (llm, doc_retriever, model_name); re-raises after logging on failure.
    """
    started = time.time()
    try:
        # "diffbot" is an alias resolved through the environment.
        if model == "diffbot":
            model = os.getenv('DEFAULT_DIFFBOT_CHAT_MODEL')

        llm, model_name = get_llm(model=model)
        logging.info(f"Model called in chat: {model} (version: {model_name})")

        retriever = get_neo4j_retriever(
            graph=graph,
            chat_mode_settings=chat_mode_settings,
            document_names=document_names,
        )
        doc_retriever = create_document_retriever_chain(llm, retriever)

        chat_setup_time = time.time() - started
        logging.info(f"Chat setup completed in {chat_setup_time:.2f} seconds")
    except Exception as e:
        logging.error(f"Error during chat setup: {e}", exc_info=True)
        raise
    return llm, doc_retriever, model_name
| |
|
| |
|
def process_chat_response(messages, history, question, model, graph, document_names, chat_mode_settings):
    """Answer *question* via retrieval-augmented generation and build the reply payload.

    Appends the AI reply to *messages* as a side effect. On any failure this
    returns an error payload rather than raising.

    Returns:
        Chatbot response dict with 'session_id' (blank, filled by the caller),
        'message', 'info' metadata, and 'user'.
    """
    try:
        llm, doc_retriever, model_version = setup_chat(model, graph, document_names, chat_mode_settings)

        docs, transformed_question = retrieve_documents(doc_retriever, messages)

        if docs:
            content, result, total_tokens, formatted_docs = process_documents(docs, question, messages, llm, model,
                                                                              chat_mode_settings)
        else:
            # No retrieval hits: answer with a canned apology and empty metadata.
            content = "I couldn't find any relevant documents to answer your question."
            result = {"sources": list(), "nodedetails": list(), "entities": list()}
            total_tokens = 0
            formatted_docs = ""

        ai_response = AIMessage(content=content)
        messages.append(ai_response)

        # NOTE(review): the summarization thread itself appears to have been
        # removed; only this log line remains — confirm intent.
        logging.info("Summarization thread started.")

        # Raw question/context/answer triple, used downstream for metrics.
        metric_details = {"question": question, "contexts": formatted_docs, "answer": content}
        return {
            "session_id": "",
            "message": content,
            "info": {
                "sources": result["sources"],
                "model": model_version,
                "nodedetails": result["nodedetails"],
                "total_tokens": total_tokens,
                "response_time": 0,
                "mode": chat_mode_settings["mode"],
                "entities": result["entities"],
                "metric_details": metric_details,
            },
            "user": "chatbot"
        }

    except Exception as e:
        logging.exception(f"Error processing chat response at {datetime.now()}: {str(e)}")
        return {
            "session_id": "",
            "message": "Something went wrong",
            "info": {
                "metrics": [],
                "sources": [],
                "nodedetails": [],
                "total_tokens": 0,
                "response_time": 0,
                "error": f"{type(e).__name__}: {str(e)}",
                "mode": chat_mode_settings["mode"],
                "entities": [],
                "metric_details": {},
            },
            "user": "chatbot"
        }
| |
|
| |
|
# Lock shared by all summarization threads. The original code built a fresh
# threading.Lock() inside the function on every call, which never actually
# serialized concurrent clear/add operations on a shared history object.
_SUMMARIZE_HISTORY_LOCK = threading.Lock()


def summarize_and_log(history, stored_messages, llm):
    """Summarize *stored_messages* with *llm* and replace *history* with the summary.

    Args:
        history: chat history object supporting clear/add_user_message/add_message.
        stored_messages: messages to condense; no-op when empty.
        llm: chat model used for summarization.

    Returns:
        True on success, False when there is nothing to summarize or on error.
    """
    logging.info("Starting summarization in a separate thread.")
    if not stored_messages:
        logging.info("No messages to summarize.")
        return False

    try:
        start_time = time.time()

        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="chat_history"),
                (
                    "human",
                    "Summarize the above chat messages into a concise message, focusing on key points and relevant details that could be useful for future conversations. Exclude all introductions and extraneous information."
                ),
            ]
        )
        summarization_chain = summarization_prompt | llm

        summary_message = summarization_chain.invoke({"chat_history": stored_messages})

        # BUG FIX: use the shared module-level lock so concurrent summarizations
        # cannot interleave their clear/add sequences.
        with _SUMMARIZE_HISTORY_LOCK:
            history.clear()
            history.add_user_message("Our current conversation summary till now")
            history.add_message(summary_message)

        history_summarized_time = time.time() - start_time
        logging.info(f"Chat History summarized in {history_summarized_time:.2f} seconds")

        return True

    except Exception as e:
        logging.error(f"An error occurred while summarizing messages: {e}", exc_info=True)
        return False
| |
|
| |
|
def create_graph_chain(model, graph):
    """Create a GraphCypherQAChain for *graph*, using *model* for Cypher and QA.

    Returns:
        Tuple (graph_chain, qa_llm, model_name).

    Raises:
        Exception: re-raised after logging. BUG FIX: the original swallowed the
        error and implicitly returned None, so callers failed later with an
        opaque TypeError when unpacking the result tuple.
    """
    try:
        logging.info(f"Graph QA Chain using LLM model: {model}")

        cypher_llm, model_name = get_llm(model)
        qa_llm, model_name = get_llm(model)
        graph_chain = GraphCypherQAChain.from_llm(
            cypher_llm=cypher_llm,
            qa_llm=qa_llm,
            validate_cypher=True,
            graph=graph,
            # LLM-generated Cypher is executed against the database; the chain
            # requires opting in to that risk explicitly.
            allow_dangerous_requests=True,
            return_intermediate_steps=True,
            top_k=3
        )

        logging.info("GraphCypherQAChain instance created successfully.")
        return graph_chain, qa_llm, model_name

    except Exception as e:
        logging.error(f"An error occurred while creating the GraphCypherQAChain instance. : {e}")
        raise
| |
|
| |
|
def get_graph_response(graph_chain, question):
    """Invoke the graph QA chain and normalize its output.

    Returns:
        Dict with 'response' (final answer), 'cypher_query' (flattened
        generated Cypher, "" if absent), and 'context' (query results).

    Raises:
        Exception: re-raised after logging. BUG FIX: the original returned
        None on error, which crashed callers invoking .get(...) on the result.
    """
    try:
        cypher_res = graph_chain.invoke({"query": question})

        response = cypher_res.get("result")
        cypher_query = ""
        context = []

        for step in cypher_res.get("intermediate_steps", []):
            if "query" in step:
                cypher_string = step["query"]
                # Strip the "cypher" fence prefix and flatten to a single line.
                cypher_query = cypher_string.replace("cypher\n", "").replace("\n", " ").strip()
            elif "context" in step:
                context = step["context"]
        return {
            "response": response,
            "cypher_query": cypher_query,
            "context": context
        }

    except Exception as e:
        logging.error(f"An error occurred while getting the graph response : {e}")
        raise
| |
|
| |
|
def process_graph_response(model, graph, question, messages, history):
    """Answer *question* in graph (Cypher QA) mode and append the AI reply to *messages*.

    Returns a chatbot response dict; on failure returns an error payload
    instead of raising.
    """
    # BUG FIX: initialized up front. The except block below reports
    # model_version, and the original raised a NameError there whenever
    # create_graph_chain failed before the assignment, masking the real error.
    model_version = None
    try:
        graph_chain, qa_llm, model_version = create_graph_chain(model, graph)

        graph_response = get_graph_response(graph_chain, question)

        ai_response_content = graph_response.get("response", "Something went wrong")
        ai_response = AIMessage(content=ai_response_content)

        messages.append(ai_response)

        # NOTE(review): only this log line remains of the summarization thread;
        # confirm whether starting the thread was removed intentionally.
        logging.info("Summarization thread started.")
        metric_details = {"question": question, "contexts": graph_response.get("context", ""),
                          "answer": ai_response_content}
        result = {
            "session_id": "",
            "message": ai_response_content,
            "info": {
                "model": model_version,
                "cypher_query": graph_response.get("cypher_query", ""),
                "context": graph_response.get("context", ""),
                "mode": "graph",
                "response_time": 0,
                "metric_details": metric_details,
            },
            "user": "chatbot"
        }

        return result

    except Exception as e:
        logging.exception(f"Error processing graph response at {datetime.now()}: {str(e)}")
        return {
            "session_id": "",
            "message": "Something went wrong",
            "info": {
                "model": model_version,
                "cypher_query": "",
                "context": "",
                "mode": "graph",
                "response_time": 0,
                "error": f"{type(e).__name__}: {str(e)}"
            },
            "user": "chatbot"
        }
| |
|
| |
|
def create_neo4j_chat_message_history(graph, session_id, write_access=True):
    """Return a chat history backend for *session_id*.

    With write access the Neo4j-backed history is used; otherwise the local,
    in-process history is returned. Re-raises after logging on failure.
    """
    try:
        if not write_access:
            return get_history_by_session_id(session_id)
        return Neo4jChatMessageHistory(
            graph=graph,
            session_id=session_id
        )
    except Exception as e:
        logging.error(f"Error creating Neo4jChatMessageHistory: {e}")
        raise
| |
|
| |
|
def get_chat_mode_settings(mode, settings_map=CHAT_MODE_CONFIG_MAP):
    """Look up the settings for *mode*, falling back to the default mode's entry.

    Returns:
        A dict of settings with "mode" set to the requested mode.

    Raises:
        Exception: re-raised after logging on unexpected errors.
    """
    default_settings = settings_map[CHAT_DEFAULT_MODE]
    try:
        # BUG FIX: copy before adding "mode". The original wrote "mode"
        # directly into the shared settings_map entry, mutating module-level
        # configuration as a side effect of every lookup.
        chat_mode_settings = dict(settings_map.get(mode, default_settings))
        chat_mode_settings["mode"] = mode

        logging.info(f"Chat mode settings: {chat_mode_settings}")

    except Exception as e:
        logging.error(f"Unexpected error: {e}", exc_info=True)
        raise

    return chat_mode_settings
| |
|
| |
|
def QA_RAG(question, session_id, mode, write_access=True):
    """Answer *question* for *session_id*, in graph mode or a RAG chat mode."""
    logging.info(f"Chat Mode: {mode}")
    graph = setup_neo4j_connection()
    model = MODEL
    # Hard-coded corpus: the AM400 user guide, pre-split by page range.
    document_names = '["AM400_User_guide_1_12.md", "AM400_User_guide_13_24.md", "AM400_User_guide_25_36.md", ' \
                     '"AM400_User_guide_37_56.md", "AM400_User_guide_57_63.md", "AM400_User_guide_64_90.md",' \
                     '"AM400_User_guide_90_104.md", "AM400_User_guide_105_114.md", "AM400_User_guide_115_140.md",' \
                     '"AM400_User_guide_141_153.md", "AM400_User_guide_154_163.md", "AM400_User_guide_164_176.md",' \
                     '"AM400_User_guide_177_178.md", "AM400_User_guide_179_211.md", "AM400_User_guide_212_228.md",' \
                     '"AM400_User_guide_229_240.md"]'
    history = create_neo4j_chat_message_history(graph, session_id, write_access)
    messages = history.messages

    messages.append(HumanMessage(content=question))

    if mode == CHAT_GRAPH_MODE:
        result = process_graph_response(model, graph, question, messages, history)
    else:
        chat_mode_settings = get_chat_mode_settings(mode=mode)
        document_names = [name.strip() for name in json.loads(document_names)]
        result = process_chat_response(messages, history, question, model, graph, document_names,
                                       chat_mode_settings)

    result["session_id"] = session_id
    return result['message']
| |
|
| |
|
# Emitted at import time so console output shows when the module starts loading.
print("Starting script...")
| |
|
| |
|
def validate_video(video_path):
    """
    Validate video file and check if frames can be extracted.
    Returns tuple of (is_valid, message, num_frames)
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Failed to open video file", 0

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"Video properties: {width}x{height} @ {fps}fps, {total_frames} frames")

        if total_frames < 1:
            return False, "Video has no frames", 0

        if fps <= 0:
            return False, "Invalid video FPS", 0

        # Probe decodability at both ends of the stream.
        ret, frame = cap.read()
        if not ret or frame is None:
            return False, "Failed to read first frame", 0

        cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
        ret, frame = cap.read()
        if not ret or frame is None:
            return False, "Failed to read last frame", 0

        return True, "Video validated successfully", total_frames

    except Exception as e:
        return False, f"Error validating video: {str(e)}", 0
    finally:
        # BUG FIX: the original called cap.release() only on the success path,
        # leaking the underlying file handle on every early return / exception.
        if cap is not None:
            cap.release()
| |
|
| |
|
def make_single_video_dataloader(video_path, config):
    """Create a dataloader for a single video using CSV file"""
    # Errors are reported by printing and returning None; nothing is raised.
    print(f"Creating data loader for video: {video_path}")

    try:
        # Quick sanity check that the file opens and has at least one frame.
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Failed to open video file")
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()

        if total_frames < 1:
            print("Video has no frames")
            return None

        # The dataset loader consumes a space-delimited "path label" listing;
        # a dummy label of 0 is written since only inference is needed.
        # NOTE(review): fixed file name in the CWD — concurrent calls would
        # race on this file; confirm single-threaded use.
        csv_path = 'video_list.csv'
        with open(csv_path, 'w') as f:
            f.write(f"{video_path} 0\n")

        print(f"Created CSV at: {csv_path}")
        # Read back for diagnostics: column 0 = sample path, column 1 = label.
        data = pd.read_csv(csv_path, header=None, delimiter=" ")
        samples = list(data.values[:, 0])
        print('samples', samples)
        labels = list(data.values[:, 1])
        print('labels', labels)
        # Evaluation transform. Augmentation-style arguments are still passed;
        # presumably they are inactive when training=False — confirm against
        # make_transforms.
        transform = make_transforms(
            training=False,
            num_views_per_clip=config['num_views_per_segment'],
            random_horizontal_flip=False,
            random_resize_aspect_ratio=(0.75, 4 / 3),
            random_resize_scale=(0.08, 1.0),
            reprob=0.25,
            auto_augment=True,
            motion_shift=False,
            crop_size=config['resolution'],
        )

        print(f"Creating data loader with CSV: {csv_path}")
        try:
            # Single-process setup: batch of 1, world_size 1, rank 0.
            data_loader, info = init_data(
                data='VideoDataset',
                root_path=[csv_path],
                transform=transform,
                batch_size=1,
                world_size=1,
                rank=0,
                clip_len=config['frames_per_clip'],
                frame_sample_rate=config['frame_step'],
                num_clips=config['num_segments'],
                allow_clip_overlap=True,
                num_workers=1,
                copy_data=False,
                drop_last=False
            )
            print(f"Data loader info: {info}")
        except Exception as e:
            # Verbose diagnostics: loader construction is the most fragile step.
            print("\n=== Data Loader Creation Error ===")
            print(f"Error type: {type(e).__name__}")
            print(f"Error message: {str(e)}")
            print("\nFull traceback:")
            traceback.print_exc(file=sys.stdout)
            return None

        return data_loader

    except Exception as e:
        print(f"Error creating data loader: {str(e)}")
        return None
| |
|
| |
|
| | class VideoClassifier: |
    def __init__(self, config):
        """Select the compute device and build the encoder/classifier pair.

        Args:
            config: dict of model settings (resolution, num_views_per_segment,
                checkpoint paths, etc.) consumed by load_models().
        """
        print("Initializing VideoClassifier...")
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        self.config = config
        self.encoder, self.classifier = self.load_models()
        # Evaluation-time transform (training=False).
        self.transform = make_transforms(
            training=False,
            num_views_per_clip=config['num_views_per_segment'],
            crop_size=config['resolution']
        )
| |
|
    def load_checkpoint(self, device, r_path, classifier, opt, scaler):
        """
        Load checkpoint for classifier, optimizer and scaler.

        Args:
            device: torch device to load the model onto
            r_path: path to checkpoint file
            classifier: classifier model to load weights into
            opt: optimizer to load state into
            scaler: gradient scaler for mixed precision training (can be None)

        Returns:
            tuple: (classifier, optimizer, scaler, epoch)

        On a missing file or any load error the inputs are returned unchanged
        with epoch 0 (best-effort, never raises to the caller).
        """
        try:
            print(f'Loading checkpoint from: {r_path}')
            # Load to CPU first; state dicts are moved to the target device
            # when applied to modules already living there.
            checkpoint = torch.load(r_path, map_location=torch.device('cpu'))

            epoch = checkpoint.get('epoch', 0)
            print(f'Checkpoint from epoch: {epoch}')

            # Classifier weights are required; everything else is optional.
            if 'classifier' in checkpoint:
                state_dict = checkpoint['classifier']
                # Strip the 'module.' prefix left by DataParallel/DDP wrapping.
                new_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith('module.'):
                        new_state_dict[k[7:]] = v
                    else:
                        new_state_dict[k] = v

                try:
                    msg = classifier.load_state_dict(new_state_dict)
                    print(f'Loaded classifier state with message: {msg}')
                except Exception as e:
                    print(f'Error loading classifier state dict: {str(e)}')
                    raise
            else:
                print('No classifier state found in checkpoint')
                raise KeyError('No classifier state in checkpoint')

            # Optimizer state is best-effort (only if present and opt given).
            if 'opt' in checkpoint and opt is not None:
                try:
                    opt.load_state_dict(checkpoint['opt'])
                    print('Loaded optimizer state')
                except Exception as e:
                    print(f'Error loading optimizer state: {str(e)}')
                    raise
            else:
                print('No optimizer state found in checkpoint')

            # Mixed-precision scaler state, if both sides have one.
            if scaler is not None and 'scaler' in checkpoint:
                try:
                    scaler.load_state_dict(checkpoint['scaler'])
                    print('Loaded scaler state')
                except Exception as e:
                    print(f'Error loading scaler state: {str(e)}')
                    raise
            elif scaler is not None:
                print('No scaler state found in checkpoint')

            # Free the (potentially large) checkpoint dict before inference.
            del checkpoint
            torch.cuda.empty_cache()

            return classifier, opt, scaler, epoch

        except FileNotFoundError:
            print(f'Checkpoint file not found: {r_path}')
            return classifier, opt, scaler, 0
        except Exception as e:
            print(f'Error loading checkpoint: {str(e)}')
            print(f'Traceback: {traceback.format_exc()}')
            return classifier, opt, scaler, 0
| |
|
    def load_models(self):
        """Build the frozen video encoder and the attentive classifier head.

        Returns:
            Tuple (encoder, classifier), both moved to self.device and set to
            eval mode. Classifier weights come from config['classifier_path'];
            encoder weights from config['pretrained_path'].
        """
        print("Loading models...")
        # Backbone encoder, configured entirely from self.config.
        encoder = init_model(
            crop_size=self.config['resolution'],
            device=self.device,
            pretrained=self.config['pretrained_path'],
            model_name=self.config['model_name'],
            patch_size=self.config['patch_size'],
            tubelet_size=self.config['tubelet_size'],
            frames_per_clip=self.config['frames_per_clip'],
            uniform_power=self.config['uniform_power'],
            checkpoint_key=self.config['checkpoint_key'],
            use_SiLU=self.config['use_silu'],
            tight_SiLU=self.config['tight_silu'],
            use_sdpa=self.config['use_sdpa'])
        if self.config['frames_per_clip'] == 1:
            # Single-frame input: aggregate per-frame features.
            encoder = FrameAggregation(encoder).to(self.device)
        else:
            # Multi-frame clips: aggregate across clips/segments.
            encoder = ClipAggregation(
                encoder,
                tubelet_size=self.config['tubelet_size'],
                attend_across_segments=self.config['attend_across_segments']
            ).to(self.device)
        print("Loading pretrained weights...")
        # strict=False: aggregation wrappers add keys absent from the checkpoint.
        checkpoint = torch.load(self.config['pretrained_path'], map_location=self.device)
        encoder.load_state_dict(checkpoint[self.config['checkpoint_key']], strict=False)
        encoder.eval()

        print("Initializing classifier...")

        # Attentive-pooling head sized to the encoder's embedding dimensions.
        classifier = AttentiveClassifier(
            embed_dim=encoder.embed_dim,
            num_heads=encoder.num_heads,
            depth=1,
            num_classes=self.config['num_classes']
        ).to(self.device)

        # Optimizer/scaler are constructed only so load_checkpoint can restore
        # their state; the hyper-parameters here are not used for training.
        optimizer, scaler, scheduler, wd_scheduler = init_opt(
            classifier=classifier,
            wd=0.01,
            start_lr=0.001,
            ref_lr=0.001,
            final_lr=0.0,
            iterations_per_epoch=1,
            warmup=0.,
            num_epochs=20,
            use_bfloat16=True)

        classifier, optimizer, scaler, epoch = self.load_checkpoint(
            device=self.device,
            r_path=self.config['classifier_path'],
            classifier=classifier,
            opt=optimizer,
            scaler=scaler
        )

        classifier.eval()

        return encoder, classifier
| |
|
| | def predict(self, video_path): |
| | """ |
| | Predict the class of a video. |
| | """ |
| | print(f"Starting prediction for video: {video_path}") |
| |
|
| | try: |
| | |
| | if not os.path.exists(video_path): |
| | print(f"Error: Video file not found: {video_path}") |
| | return "Error: Video file not found", 0.0 |
| |
|
| | |
| | file_size = os.path.getsize(video_path) |
| | if file_size == 0: |
| | print("Error: Video file is empty") |
| | return "Error: Empty video file", 0.0 |
| |
|
| | print(f"Video file size: {file_size} bytes") |
| |
|
| | |
| | data_loader = make_single_video_dataloader(video_path, self.config) |
| |
|
| | if data_loader is None: |
| | return "Error: Failed to create data loader", 0.0 |
| |
|
| | |
| | with torch.no_grad(): |
| | try: |
| | |
| | batch = next(iter(data_loader)) |
| |
|
| | |
| | clips = [ |
| | [dij.to(self.device, non_blocking=True) for dij in di] |
| | for di in batch[0] |
| | ] |
| | clip_indices = [d.to(self.device, non_blocking=True) for d in batch[2]] |
| |
|
| | print("Running encoder forward pass...") |
| | encoder_outputs = self.encoder(clips, clip_indices) |
| |
|
| | print("Running classifier forward pass...") |
| | if self.config['attend_across_segments']: |
| | outputs = [self.classifier(o) for o in encoder_outputs] |
| | probs = sum([F.softmax(o, dim=1) for o in outputs]) / len(outputs) |
| | else: |
| | outputs = [[self.classifier(ost) for ost in os] for os in encoder_outputs] |
| | probs = sum([sum([F.softmax(ost, dim=1) for ost in os]) |
| | for os in outputs]) / len(outputs) / len(outputs[0]) |
| |
|
| | |
| | pred_class = probs.argmax(dim=1).item() |
| | confidence = probs.max(dim=1).values.item() |
| |
|
| | class_name = self.config['class_names'][pred_class] |
| | print(f'Predicted class: {class_name} (index: {pred_class})') |
| | print(f'Confidence: {confidence:.4f}') |
| |
|
| | return class_name, confidence |
| |
|
| | except StopIteration: |
| | print("Error: No data in data loader") |
| | return "Error: Failed to load video frames", 0.0 |
| |
|
| | except Exception as e: |
| | print(f"Error during prediction: {str(e)}") |
| | return f"Error during prediction: {str(e)}", 0.0 |
| |
|
| | except Exception as e: |
| | print(f"Error in prediction pipeline: {str(e)}") |
| | return f"Error in prediction pipeline: {str(e)}", 0.0 |
| |
|
| |
|
def process_video(classifier, video_path):
    """Classify a video file, converting .webm input to .mp4 first.

    Args:
        classifier: Object exposing ``predict(path) -> (label, confidence)``.
        video_path: Path string to the video, or None.

    Returns:
        tuple: ``(label, confidence)`` from the classifier, or
        ``("Error...", 0.0)`` / ``("No video uploaded", 0.0)`` on failure.

    Fix vs. original: the temporary .mp4 produced for .webm input is now
    always deleted — previously it leaked when ffmpeg failed (early return)
    or when the conversion raised.
    """
    print("Processing video...")
    if video_path is None:
        return "No video uploaded", 0.0

    mp4_path = None

    try:
        if isinstance(video_path, str):
            print(f"Processing video from path: {video_path}")

            # Cleanup of the temp .mp4 must cover conversion failures too,
            # hence the finally spans conversion AND prediction.
            try:
                if video_path.endswith('.webm'):
                    print("Converting .webm to .mp4 using ffmpeg...")
                    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                        mp4_path = temp_file.name

                    # List-form argv (shell=False) — no shell injection risk.
                    command = ['ffmpeg', '-y', '-i', video_path, '-c:v', 'libx264', '-preset', 'fast', '-crf', '23', '-c:a',
                               'aac', mp4_path]
                    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

                    if result.returncode != 0:
                        print(f"FFmpeg error: {result.stderr.decode()}")
                        return "Error during video conversion", 0.0

                    print(f"Converted video saved to: {mp4_path}")
                    video_path = mp4_path

                return classifier.predict(video_path)
            finally:
                # Remove the converted file regardless of outcome.
                if mp4_path and os.path.exists(mp4_path):
                    os.unlink(mp4_path)
        else:
            print(f"Error: Unexpected video input type: {type(video_path)}")
            return "Error: Invalid video format", 0.0
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return f"Error: {str(e)}", 0.0
| |
|
| |
|
def process_highlights_and_images(qa_reuslt_pure):
    """Highlight answer phrases that match figure captions in the graph DB.

    Embeds each line of the answer, retrieves the most similar Figure node
    per line (cosine similarity > 0.76), then wraps answer phrases whose
    similarity to any retrieved caption is >= 0.75 in a yellow <mark> tag.

    Args:
        qa_reuslt_pure: Raw QA answer text (may be empty/None).

    Returns:
        tuple: ``(highlighted_html, image_refs, image_info)``.
    """
    if not qa_reuslt_pure:
        return "", [], {}

    answer_text = qa_reuslt_pure

    print('qa_reuslt_pure', qa_reuslt_pure)
    # Strip markdown bold markers line by line before embedding.
    stripped_lines = [raw.replace("**", "") for raw in qa_reuslt_pure.strip().split('\n')]

    graph = setup_neo4j_connection()

    try:
        matched_captions = []
        matched_images = []
        matched_sims = []

        for answer_line in stripped_lines:
            line_vec = EMBEDDING_FUNCTION.embed_query(answer_line)

            query = """
            WITH $line_embedding AS line_embedding
            MATCH (f:Figure)
            WHERE vector.similarity.cosine(f.embedding, line_embedding) > $similarity_threshold
            RETURN f, vector.similarity.cosine(f.embedding, line_embedding) AS similarity
            ORDER BY similarity DESC
            LIMIT 1
            """

            rows = graph.query(query, {
                "line_embedding": line_vec,
                "similarity_threshold": 0.76
            })

            for row in rows:
                fig_node = row["f"]
                sim_score = row["similarity"]
                if 'img_ref' in fig_node:
                    matched_images.append(fig_node['img_ref'])
                    matched_captions.append(fig_node['description'])
                    matched_sims.append(sim_score)

        # Mark up answer phrases that resemble any retrieved caption.
        highlighted = answer_text
        if matched_captions:
            phrases = answer_text.split('. ')
            phrase_vecs = [EMBEDDING_FUNCTION.embed_query(p) for p in phrases]
            caption_vecs = [EMBEDDING_FUNCTION.embed_query(c) for c in matched_captions]
            sim_matrix = cosine_similarity(phrase_vecs, caption_vecs)

            pieces = []
            for row_idx, phrase in enumerate(phrases):
                if np.max(sim_matrix[row_idx]) >= 0.75:
                    pieces.append(f"<mark style='background-color: yellow;'>{phrase}</mark>")
                else:
                    pieces.append(phrase)
            # Re-joining with ". " restores the separators split() removed.
            highlighted = ". ".join(pieces)

        image_info = [
            {"caption": caption, "similarity": f"{sim:.3f}"}
            for caption, sim in zip(matched_captions, matched_sims)
        ]

        return highlighted, matched_images, image_info

    finally:
        if hasattr(graph, 'close'):
            graph.close()
| |
|
| |
|
class AudioAgent:
    """Voice-driven conversational agent coordinating the demo pipeline.

    Transcribes user speech (Whisper), decides which pipeline step the user
    wants (GPT-4), runs that step (video analysis / QA / highlights), and
    speaks the result back (OpenAI TTS). Conversation progress is tracked in
    ``self.conversation_state``; stages advance in the order
    video_analyzed -> questions_generated -> questions_answered ->
    highlights_generated.
    """

    def __init__(self, api_key: str):
        self.client = OpenAI(api_key=api_key)
        # Scratch directory for generated TTS audio files.
        # NOTE(review): never removed — temp files accumulate for the
        # process lifetime.
        self.temp_dir = tempfile.mkdtemp()
        print('temp dir', self.temp_dir)

        # Mutable pipeline state shared across process_request calls.
        self.conversation_state = {
            'current_stage': None,
            'video_prediction': None,
            'video_confidence': None,
            'generated_question': None,
            'pure_response': None,
            'highlighted_response': None,
            'gallery': None,
            'image_infos': None,
            'webcam_active': False
        }

    def _get_temp_path(self, filename: str) -> str:
        """Return a path for *filename* inside this agent's temp directory."""
        temp_path = os.path.join(self.temp_dir, filename)
        print('temp_path', temp_path)
        return temp_path

    def transcribe_audio(self, audio_path: str) -> str:
        """Transcribe the audio file at *audio_path* with Whisper.

        Returns the transcript text, or "" on any failure (missing file,
        API error) — callers treat "" as "no usable input".
        """
        try:
            if not os.path.isfile(audio_path):
                print(f"Invalid audio path: {audio_path}")
                return ""

            with open(audio_path, "rb") as audio_file:
                transcript = self.client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file
                )
                print('transcript text', transcript.text)
                return transcript.text
        except Exception as e:
            print(f"Error in transcription: {e}")
            return ""

    def text_to_speech(self, text: str, filename: str) -> str:
        """Synthesize *text* to an mp3 named *filename* in the temp dir.

        Returns the output path, or "" on failure.
        """
        try:
            output_path = self._get_temp_path(filename)
            print('output_path', output_path)
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="alloy",
                input=text
            )
            # NOTE(review): stream_to_file is deprecated in newer openai
            # SDKs in favor of with_streaming_response — confirm SDK version.
            response.stream_to_file(output_path)
            return output_path
        except Exception as e:
            print(f"Error in TTS conversion: {e}")
            return ""

    def determine_action(self, instruction: str) -> str:
        """Classify the user's intent into one of five action names via GPT-4.

        Returns one of: analyze_video, ask_question, answer_questions,
        generate_highlights, unclear. Falls back to "unclear" on API error.
        NOTE(review): the model's reply is not validated against this set.
        """
        messages = [
            {
                "role": "system",
                "content": f"""Determine what action the user wants to take based on the current conversation state:
                Current state: {self.conversation_state['current_stage']}
                Video prediction: {self.conversation_state['video_prediction']}

                Possible actions:
                - analyze_video: If they want to analyze a video
                - ask_question: If they are asking their own question
                - answer_questions: If they want to answer questions
                - generate_highlights: If they want to process answers for highlights
                - unclear: If the request is unclear

                Return only one of these action names."""
            },
            {"role": "user", "content": str(instruction)}
        ]

        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=messages
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error determining action: {e}")
            return "unclear"

    def extract_question(self, instruction: str) -> str:
        """Pull the bare question out of a verbose voice instruction.

        Falls back to returning *instruction* unchanged on API error.
        """
        messages = [
            {
                "role": "system",
                "content": "Extract the actual question from the user's input. Return only the question itself."
            },
            {"role": "user", "content": str(instruction)}
        ]

        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=messages
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error extracting question: {e}")
            return instruction

    def process_request(self, instruction: str, **kwargs) -> Dict[str, Any]:
        """Dispatch *instruction* to the appropriate pipeline stage.

        Keyword args used: webcam_input (video source), classifier
        (VideoClassifier), session_id, mode (RAG mode string).

        Returns a dict with either an "error" key, or "results",
        "audio_path" (TTS reply, may be None) and "action", with the
        per-action results also merged at top level.
        """
        try:
            action = self.determine_action(instruction)
            results = {}
            final_audio_path = None

            # Each branch below: run the stage, update conversation_state,
            # and synthesize a spoken confirmation.

            if action == "analyze_video":
                video_input = kwargs.get('webcam_input')
                print(f"Video input received: {video_input}, type: {type(video_input)}")

                video_path = None
                # Normalize the several shapes Gradio may hand us.
                if isinstance(video_input, tuple):
                    # (path, subtitle) style tuple
                    video_path = video_input[0] if video_input else None
                elif isinstance(video_input, str):
                    # plain filepath
                    video_path = video_input
                elif isinstance(video_input, dict) and 'name' in video_input:
                    # file-info dict
                    video_path = video_input['name']

                print(f"Processed video path: {video_path}")

                if video_path and os.path.isfile(video_path):
                    classifier = kwargs.get('classifier')
                    if classifier:
                        video_result = process_video(classifier, video_path)
                        self.conversation_state['video_prediction'] = video_result[0]
                        self.conversation_state['video_confidence'] = video_result[1]
                        self.conversation_state['current_stage'] = 'video_analyzed'

                        results.update({
                            'video_prediction': video_result[0],
                            'video_confidence': video_result[1]
                        })

                        final_audio_path = self.text_to_speech(
                            f"Video analysis complete. Detected action: {video_result[0]} with confidence {video_result[1]:.2f}. Would you like to ask questions about this activity?",
                            "response_video.mp3"
                        )
                # NOTE(review): if the path is invalid or classifier missing,
                # this falls through and returns empty results silently.

            elif action == "ask_question":
                if self.conversation_state['video_prediction']:
                    # User supplied their own question about the detected action.
                    custom_question = self.extract_question(instruction)
                    self.conversation_state['generated_question'] = custom_question
                    self.conversation_state['current_stage'] = 'questions_generated'
                    results['generated_question'] = custom_question
                    final_audio_path = self.text_to_speech(
                        f"I understand your question: '{custom_question}'. Would you like me to answer it?",
                        "response_custom_question.mp3"
                    )
                else:
                    return {"error": "Please analyze a video first"}

            elif action == "answer_questions":
                if self.conversation_state['generated_question']:
                    # QA_RAG is defined elsewhere in this project.
                    pure_response = QA_RAG(
                        self.conversation_state['generated_question'],
                        kwargs.get('session_id', 1),
                        kwargs.get('mode', 'CHAT_VECTOR_GRAPH_MODE')
                    )
                    self.conversation_state['pure_response'] = pure_response
                    self.conversation_state['current_stage'] = 'questions_answered'

                    results.update({
                        'pure_response': pure_response
                    })

                    final_audio_path = self.text_to_speech(
                        f"Sure! I've answered the questions. {pure_response}. Would you like me to generate highlights and related images?",
                        "response_qa.mp3"
                    )
                else:
                    return {"error": "Please generate questions first"}

            elif action == "generate_highlights":
                if self.conversation_state['pure_response']:
                    highlighted_response, gallery, image_infos = process_highlights_and_images(
                        self.conversation_state['pure_response']
                    )
                    self.conversation_state.update({
                        'highlighted_response': highlighted_response,
                        'gallery': gallery,
                        'image_infos': image_infos,
                        'current_stage': 'highlights_generated'
                    })

                    results.update({
                        'highlighted_response': highlighted_response,
                        'gallery': gallery,
                        'image_infos': image_infos
                    })

                    final_audio_path = self.text_to_speech(
                        "I've generated highlights and retrieved related images. Is there anything else you'd like to know?",
                        "response_highlights.mp3"
                    )
                else:
                    return {"error": "Please answer questions first"}

            else:
                # Unknown intent: suggest the next valid step for this stage.
                suggestion = ""
                if not self.conversation_state['current_stage']:
                    suggestion = "You can start by asking me to analyze a video."
                elif self.conversation_state['current_stage'] == 'video_analyzed':
                    suggestion = "You can ask me to generate questions about the activity."
                elif self.conversation_state['current_stage'] == 'questions_generated':
                    suggestion = "You can ask me to answer the questions."
                elif self.conversation_state['current_stage'] == 'questions_answered':
                    suggestion = "You can ask me to generate highlights and images."

                return {
                    "error": f"I'm not sure what you want me to do. {suggestion}"
                }

            # Results are returned both nested and flattened for convenience.
            return {
                "results": results,
                "audio_path": final_audio_path,
                "action": action,
                **results
            }

        except Exception as e:
            print(f"Error in process_request: {e}")
            return {"error": f"Error processing request: {e}"}
| |
|
| |
|
def process_audio_input(
    agent,
    audio_path: str,
    webcam_input: str = None,
    classifier=None,
    mode: str = 'CHAT_VECTOR_GRAPH_MODE',
    session_id: int = 1
) -> Tuple[Dict[str, Any], str]:
    """Transcribe a voice recording and route it through the agent.

    Args:
        agent: AudioAgent-like object with transcribe_audio / process_request.
        audio_path: Path to the recorded user audio.
        webcam_input: Optional video source forwarded to the agent.
        classifier: Optional video classifier forwarded to the agent.
        mode: RAG retrieval mode name.
        session_id: Chat session identifier.

    Returns:
        ``(result_dict, reply_audio_path)``; the audio path is "" when
        transcription fails or the agent reports an error.
    """
    transcribed = agent.transcribe_audio(audio_path)
    if not transcribed:
        return {"error": "Error in transcription"}, ""

    outcome = agent.process_request(
        transcribed,
        webcam_input=webcam_input,
        classifier=classifier,
        mode=mode,
        session_id=session_id
    )

    # Error outcomes carry no spoken reply.
    reply_audio = "" if "error" in outcome else outcome.get("audio_path", "")
    return outcome, reply_audio
| |
|
| |
|
def main():
    """Download model weights, build the classifier, and assemble the Gradio UI.

    Returns:
        gr.Blocks: the constructed (un-launched) demo app.
    """
    print("Starting main function...")

    # Local imports: only needed for the one-time weight download below.
    import os
    import urllib.request

    # Destination path -> download URL for the required checkpoints.
    models = {
        "pretrained_model/vith16.pth.tar": "https://dl.fbaipublicfiles.com/jepa/vith16/vith16.pth.tar",
        "pretrained_model/v-jepa-224-classifier/jepa-latest.pth.tar": "https://huggingface.co/IELTS8/pretrained_model/resolve/main/v-jepa-224-classifier/jepa-latest.pth.tar"
    }

    def ensure_directory_exists(file_path):
        # Create the parent directory of file_path if missing.
        directory = os.path.dirname(file_path)
        if not os.path.exists(directory):
            os.makedirs(directory)

    def download_file(file_path, url):
        # Download url to file_path unless it is already present.
        if not os.path.exists(file_path):
            print(f"Downloading {file_path} from {url}...")
            ensure_directory_exists(file_path)
            urllib.request.urlretrieve(url, file_path)
            print(f"Downloaded: {file_path}")
        else:
            print(f"File already exists: {file_path}")

    def update_video(video_path):
        # Identity passthrough used as a Gradio callback.
        return video_path

    # Fetch any missing checkpoints before model construction.
    for file_path, url in models.items():
        download_file(file_path, url)

    # V-JEPA ViT-H/16 configuration plus the 11 L-PBF machine-operation
    # class labels this demo recognizes.
    model_config = {
        'model_name': 'vit_huge',
        'resolution': 224,
        'patch_size': 16,
        'frames_per_clip': 16,
        'tubelet_size': 2,
        'uniform_power': True,
        'use_sdpa': True,
        'use_silu': False,
        'tight_silu': False,
        'num_classes': 11,
        'pretrained_path': 'pretrained_model/vith16.pth.tar',
        'classifier_path': 'pretrained_model/v-jepa-224-classifier/jepa-latest.pth.tar',
        'checkpoint_key': 'target_encoder',
        'num_segments': 2,
        'num_views_per_segment': 3,
        'frame_step': 3,
        'attend_across_segments': True,
        'class_names': [
            'clean the build plate', 'clean the recoater', 'insert the build plate screw',
            'measure the build plate thickness', 'open the process chamber door',
            'operate valves of the large safe change filter',
            'place build plate inside the process chamber',
            'place recoater into the process chamber', 'press reset button on the control panel',
            'press stop button on the control panel',
            'wipe laser window inside the process chamber'
        ]
    }
    # Distributed init is required by the project's model utilities even for
    # this single-process demo.
    world_size, rank = init_distributed()
    print(f'Initialized (rank/world-size) {rank}/{world_size}')
    print("Creating classifier...")
    # VideoClassifier is defined elsewhere in this file/project.
    classifier = VideoClassifier(model_config)

    # Translucent theme for the demo UI.
    custom_theme = gr.themes.Soft().set(
        body_background_fill="rgba(255, 255, 255, 0.3)",
        block_background_fill="rgba(255, 255, 255, 0.3)",
        panel_background_fill="rgba(255, 255, 255, 0.3)",

        block_border_width="1px",
        block_border_color="rgba(128, 128, 128, 0.3)",
        block_radius="8px",

        block_title_text_color="rgba(0, 0, 0, 0.8)",
        block_label_text_color="rgba(0, 0, 0, 0.7)",
        body_text_color="rgba(0, 0, 0, 0.8)"
    )

    print("Setting up Gradio interface...")
    sample_videos = ["example_1.mp4", "example_2.mp4", "example_3.mp4", "example_4.mp4", "example_5.mp4"]
    with gr.Blocks(theme=custom_theme, css="""
        .gradio-container {
            background: rgba(255, 255, 255, 0.7) !important;
        }
        .block {
            background: rgba(255, 255, 255, 0.5) !important;
            backdrop-filter: blur(10px);
        }
        .panel {
            background: rgba(255, 255, 255, 0.3) !important;
        }
    """) as demo:
        gr.Markdown("# 🏭 Sim2Real: Action Recognition (Synthetic Data) & Knowledge Guidance (Knowledge Graph)")
        # Debounce state: last-processed timestamp and in-flight flag.
        last_processed_time = gr.State(time.time())
        is_processing = gr.State(False)
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 🎬 Operation Recognition")
                with gr.Row():
                    with gr.Column(scale=3):
                        video_input = gr.Video(
                            label="Upload Video",
                            sources=["upload", "webcam"],
                            height=400,
                            width=600
                        )
                    with gr.Column(scale=1):
                        # Clickable example clips.
                        with gr.Group():
                            gr.Examples(
                                examples=sample_videos,
                                inputs=video_input,
                                examples_per_page=5,
                                label="Sample Videos"
                            )
                with gr.Row():
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Speak your request",
                        streaming=False
                    )

                with gr.Row():
                    text_output = gr.Textbox(label="Response")
                    audio_output = gr.Audio(
                        label="Audio Response",
                        autoplay=True
                    )

                status = gr.Textbox(label="Status", value="Ready")

                # Hidden components that carry intermediate pipeline values.
                video_prediction = gr.Textbox(visible=False)
                video_confidence = gr.Number(visible=False)
                generated_question = gr.Textbox(visible=False)

            with gr.Column(scale=3):
                gr.Markdown("### 🔍 Knowledge Guidance")
                with gr.Row():
                    mode = gr.Dropdown(
                        choices=[
                            'CHAT_VECTOR_GRAPH_MODE',
                            'CHAT_GRAPH_MODE',
                            'CHAT_VECTOR_MODE',
                            'CHAT_FULLTEXT_MODE',
                            'CHAT_VECTOR_GRAPH_FULLTEXT_MODE'
                        ],
                        label="Select Mode",
                        value='CHAT_VECTOR_GRAPH_MODE',
                        visible=False,
                    )
                    session_id = gr.Number(
                        label="Session ID",
                        value=1,
                        visible=False
                    )

                pure_response = gr.Markdown(visible=False)
                highlighted_response = gr.HTML(label="Highlighted Response")

                with gr.Row():
                    gallery = gr.Gallery(label="Related Images", height=150)
                    image_infos = gr.JSON(label="Image Information", open=False)

        # One agent instance shared by all sessions of this process.
        agent = AudioAgent(api_key=os.getenv("OPENAI_API_KEY"))

        def process_with_status(audio_path, last_time, is_proc, video_path, mode_value, session_id_value):
            # Handler for audio_input.change: debounce, run the voice
            # pipeline, and fan results out to the UI components.
            print("=== Input Debug ===")
            print(f"Audio path: {audio_path}, type: {type(audio_path)}")
            print(f"Video path: {video_path}, type: {type(video_path)}")
            print(f"Mode value: {mode_value}")
            print(f"Session ID: {session_id_value}")

            # Normalize the video input shape (tuple vs plain path).
            video_file = None
            if isinstance(video_path, tuple) and len(video_path) > 0:
                video_file = video_path[0]
                print(f"Extracted video file from tuple: {video_file}")
            elif isinstance(video_path, str):
                video_file = video_path
                print(f"Using video path directly: {video_file}")
            else:
                print(f"Unexpected video_path type: {type(video_path)}, value: {video_path}")

            current_time = time.time()

            # Debounce: ignore triggers within 4s of the last one or while
            # a request is in flight. (Dict return updates only a subset of
            # outputs — Gradio supports component-keyed dicts.)
            if current_time - last_time < 4 or is_proc:
                return {
                    status: "Processing previous request...",
                    last_processed_time: last_time,
                    is_processing: is_proc
                }

            # Mark busy for the duration of this request.
            is_proc = True

            if audio_path is None:
                return {
                    status: "No audio detected",
                    last_processed_time: current_time,
                    is_processing: False
                }

            result, audio_path = process_audio_input(
                agent=agent,
                audio_path=audio_path,
                webcam_input=video_path,
                classifier=classifier,
                mode=mode_value,
                session_id=session_id_value
            )

            print("=== Output Debug ===")
            print(f"Result: {result}")
            print(f"Audio response path: {audio_path}")

            if "error" in result:
                return result["error"], None, "", 0, [], "", "", [], {}, f"Error: {result['error']}"
            # Positional 10-tuple matching the outputs list of the change()
            # wiring below.
            return (
                result,
                audio_path,
                result.get("video_prediction", ""),
                result.get("video_confidence", 0),
                result.get("generated_question", []),
                result.get("pure_response", ""),
                result.get("highlighted_response", ""),
                result.get("gallery", []),
                result.get("image_infos", {}),
                "Ready"
            )

        audio_input.change(
            fn=process_with_status,
            inputs=[
                audio_input,
                last_processed_time,
                is_processing,
                video_input,
                mode,
                session_id
            ],
            outputs=[
                text_output,
                audio_output,
                video_prediction,
                video_confidence,
                generated_question,
                pure_response,
                highlighted_response,
                gallery,
                image_infos,
                status
            ]
        )

    print("Launching interface...")
    return demo
| |
|
| |
|
if __name__ == "__main__":
    # Entry point: build the Gradio app and serve it with a public share link.
    print("Script started")
    demo = main()
    demo.launch(share=True)
| |
|