# metalmind / app.py
# IELTS8's picture
# Update app.py
# ac0e8cd verified
#!/usr/bin/env python3
import os
import subprocess
import sys
import traceback
import gradio as gr
from openai import OpenAI
from typing import Dict
import cv2
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from video_classification_frozen.eval import init_opt, init_model
from src.datasets.data_manager import init_data
from src.utils.distributed import init_distributed
import torch
import torch.nn.functional as F
import tempfile
from src.models.attentive_pooler import AttentiveClassifier
from video_classification_frozen.utils import make_transforms, FrameAggregation, ClipAggregation
import logging
import threading
from datetime import datetime
from typing import Any, Tuple
import os
from dotenv import load_dotenv
import time
from langchain.chains.llm import LLMChain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
from langchain_community.vectorstores import Neo4jVector
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableBranch
from langchain_openai import ChatOpenAI
from langchain_text_splitters import TokenTextSplitter
from kg_sys.llm import get_llm
from kg_sys.shared.constants import *
from kg_sys.shared.common_fn import create_graph_database_connection, load_embedding_model
from langchain_core.messages import HumanMessage, AIMessage
from langchain_community.chat_message_histories import Neo4jChatMessageHistory
from langchain.chains import GraphCypherQAChain
import json
from kg_sys.shared.constants import CHAT_SYSTEM_TEMPLATE, CHAT_TOKEN_CUT_OFF
# Load environment variables from a local .env file (no-op if the file is absent)
load_dotenv()
# Constants
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
# SECURITY: the Neo4j connection details (including the password) were
# hard-coded in source. They are now read from the environment first; the
# original literals remain only as backward-compatible fallbacks and should
# be rotated and removed from version control.
NEO4J_URI = os.getenv("NEO4J_URI", "neo4j+s://02616835.databases.neo4j.io")
USERNAME = os.getenv("NEO4J_USERNAME", "neo4j")
PASSWORD = os.getenv("NEO4J_PASSWORD", "KBbDcrVF5pQklhEzfy8wXP9a5wIV_IfV7Qdr64SvTkQ")
MODEL = os.getenv("VITE_LLM_MODELS")
# Embedding function used for all vector search; the second return value
# (embedding dimension) is unused here.
EMBEDDING_FUNCTION, _ = load_embedding_model(EMBEDDING_MODEL)
# Helper functions from your code
def setup_neo4j_connection():
    """Open a Neo4j graph connection using the module-level credentials.

    Returns the handle produced by ``create_graph_database_connection``
    (project helper); always targets the default 'neo4j' database.
    """
    return create_graph_database_connection(uri=NEO4J_URI, userName=USERNAME, password=PASSWORD, database='neo4j')
class SessionChatHistory:
    """Process-local registry of in-memory chat histories, keyed by session id."""

    # Shared across all callers for the lifetime of the process.
    history_dict = {}

    @classmethod
    def get_chat_history(cls, session_id):
        """Retrieve or create chat message history for a given session ID."""
        if session_id in cls.history_dict:
            logging.info(f"Retrieved existing ChatMessageHistory Local for session ID: {session_id}")
        else:
            logging.info(f"Creating new ChatMessageHistory Local for session ID: {session_id}")
            cls.history_dict[session_id] = ChatMessageHistory()
        return cls.history_dict[session_id]
class CustomCallback(BaseCallbackHandler):
    """Callback that captures the LLM's rewritten question when generation ends."""

    def __init__(self):
        # Populated once on_llm_end fires; None until then.
        self.transformed_question = None

    def on_llm_end(self, response, **kwargs: Any) -> None:
        logging.info("question transformed")
        first_generation = response.generations[0][0]
        self.transformed_question = first_generation.text.strip()
def get_history_by_session_id(session_id):
    """Fetch (or lazily create) the local chat history for ``session_id``.

    Logs and re-raises any failure so callers can handle it themselves.
    """
    try:
        return SessionChatHistory.get_chat_history(session_id)
    except Exception as exc:
        logging.error(f"Failed to get history for session ID '{session_id}': {exc}")
        raise
def get_total_tokens(ai_response, llm):
    """Best-effort extraction of total token usage from an LLM response.

    Only ChatOpenAI responses expose ``response_metadata['token_usage']``;
    any other model type — or any lookup error — yields 0.
    """
    total_tokens = 0
    try:
        if isinstance(llm, ChatOpenAI):
            usage = ai_response.response_metadata.get('token_usage', {})
            total_tokens = usage.get('total_tokens', 0)
        else:
            logging.warning(f"Unrecognized language model: {type(llm)}. Returning 0 tokens.")
    except Exception as e:
        logging.error(f"Error retrieving total tokens: {e}")
        total_tokens = 0
    return total_tokens
def clear_chat_history(graph, session_id, local=False):
    """Wipe the chat history for a session.

    local=False clears the Neo4j-backed history; local=True clears the
    in-process one. Returns a chatbot-style status payload either way.
    """
    try:
        if local:
            history = get_history_by_session_id(session_id)
        else:
            history = Neo4jChatMessageHistory(graph=graph, session_id=session_id)
        history.clear()
        message = "The chat history has been cleared."
    except Exception as e:
        logging.error(f"Error clearing chat history for session {session_id}: {e}")
        message = "Failed to clear chat history."
    return {
        "session_id": session_id,
        "message": message,
        "user": "chatbot"
    }
def get_sources_and_chunks(sources_used, docs):
    """Collect chunk details for documents whose source was actually used.

    Args:
        sources_used: iterable of source names cited in the answer.
        docs: retrieved documents; each doc's ``metadata`` may carry a
            'source' string and a 'chunkdetails' list of dicts with
            'id' and 'score' keys.

    Returns:
        dict with 'sources' (the input, unchanged) and 'chunkdetails'
        (deduplicated by (id, score rounded to 4 places)).
    """
    chunkdetails_list = []
    sources_used_set = set(sources_used)
    seen_ids_and_scores = set()
    for doc in docs:
        try:
            source = doc.metadata.get("source")
            chunkdetails = doc.metadata.get("chunkdetails", [])
            if source not in sources_used_set:
                continue
            for chunkdetail in chunkdetails:
                # Renamed from `id` to avoid shadowing the builtin.
                chunk_id = chunkdetail.get("id")
                score = round(chunkdetail.get("score", 0), 4)
                key = (chunk_id, score)
                if key not in seen_ids_and_scores:
                    seen_ids_and_scores.add(key)
                    chunkdetails_list.append({**chunkdetail, "score": score})
        except Exception as e:
            logging.error(f"Error processing document: {e}")
    return {
        'sources': sources_used,
        'chunkdetails': chunkdetails_list,
    }
def get_rag_chain(llm, system_template=CHAT_SYSTEM_TEMPLATE):
    """Build the QA chain: system prompt + chat history + user turn, piped into ``llm``."""
    try:
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_template),
            MessagesPlaceholder(variable_name="messages"),
            ("human", "User question: {input}"),
        ])
        return prompt | llm
    except Exception as e:
        logging.error(f"Error creating RAG chain: {e}")
        raise
def format_documents(documents, model):
    """Rank retrieved documents, keep the model's token-cutoff worth, and render them.

    Returns (formatted_text, sources_set, entities, global_communities);
    entities/communities come from the last ranked document that carried them.
    """
    # Per-model cutoff on how many documents fit the prompt; 4 by default.
    prompt_token_cutoff = 4
    for model_names, value in CHAT_TOKEN_CUT_OFF.items():
        if model in model_names:
            prompt_token_cutoff = value
            break

    ranked = sorted(
        documents,
        key=lambda doc: doc.state.get("query_similarity_score", 0),
        reverse=True,
    )[:prompt_token_cutoff]

    formatted_docs = []
    sources = set()
    entities = {}
    global_communities = []
    for doc in ranked:
        try:
            source = doc.metadata.get('source', "unknown")
            sources.add(source)
            if 'entities' in doc.metadata:
                entities = doc.metadata['entities']
            if 'communitydetails' in doc.metadata:
                global_communities = doc.metadata["communitydetails"]
            formatted_docs.append(
                "Document start\n"
                f"This Document belongs to the source {source}\n"
                f"Content: {doc.page_content}\n"
                "Document end\n"
            )
        except Exception as e:
            logging.error(f"Error formatting document: {e}")
    return "\n\n".join(formatted_docs), sources, entities, global_communities
def process_documents(docs, question, messages, llm, model, chat_mode_settings):
    """Run the RAG chain over formatted docs and assemble the result payload.

    Returns (answer_text, result_dict, total_tokens, formatted_docs);
    logs and re-raises on any failure.
    """
    start_time = time.time()
    try:
        formatted_docs, sources, entitydetails, communities = format_documents(docs, model)

        ai_response = get_rag_chain(llm=llm).invoke({
            "messages": messages[:-1],
            "context": formatted_docs,
            "input": question,
        })

        result = {'sources': [], 'nodedetails': {}, 'entities': {}}
        node_details = {"chunkdetails": [], "entitydetails": [], "communitydetails": []}
        entities = {'entityids': [], "relationshipids": []}

        mode = chat_mode_settings["mode"]
        if mode == CHAT_ENTITY_VECTOR_MODE:
            node_details["entitydetails"] = entitydetails
        elif mode == CHAT_GLOBAL_VECTOR_FULLTEXT_MODE:
            node_details["communitydetails"] = communities
        else:
            sources_and_chunks = get_sources_and_chunks(sources, docs)
            result['sources'] = sources_and_chunks['sources']
            node_details["chunkdetails"] = sources_and_chunks["chunkdetails"]
            entities.update(entitydetails)

        result["nodedetails"] = node_details
        result["entities"] = entities

        content = ai_response.content
        total_tokens = get_total_tokens(ai_response, llm)
        logging.info(f"Final response predicted in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logging.error(f"Error processing documents: {e}")
        raise
    return content, result, total_tokens, formatted_docs
def retrieve_documents(doc_retriever, messages):
    """Invoke the retriever chain, capturing the rewritten question via callback.

    Returns (docs, transformed_question); wraps any failure in RuntimeError.
    """
    start_time = time.time()
    try:
        handler = CustomCallback()
        docs = doc_retriever.invoke({"messages": messages}, {"callbacks": [handler]})
        transformed_question = handler.transformed_question
        if transformed_question:
            logging.info(f"Transformed question : {transformed_question}")
        logging.info(f"Documents retrieved in {time.time() - start_time:.2f} seconds")
        return docs, transformed_question
    except Exception as e:
        error_message = f"Error retrieving documents: {str(e)}"
        logging.error(error_message)
        raise RuntimeError(error_message)
def create_document_retriever_chain(llm, retriever):
    """Wrap ``retriever`` with query rewriting plus embedding-based compression.

    Single-message conversations pass the user text straight to the
    compression retriever; longer conversations first rewrite the query
    with the LLM.
    """
    try:
        logging.info("Starting to create document retriever chain")

        query_transform_prompt = ChatPromptTemplate.from_messages([
            ("system", QUESTION_TRANSFORM_TEMPLATE),
            MessagesPlaceholder(variable_name="messages"),
        ])

        # Split into token-sized chunks, then drop chunks below the
        # embedding-similarity threshold.
        compressor = DocumentCompressorPipeline(transformers=[
            TokenTextSplitter(chunk_size=CHAT_DOC_SPLIT_SIZE, chunk_overlap=0),
            EmbeddingsFilter(
                embeddings=EMBEDDING_FUNCTION,
                similarity_threshold=CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD,
            ),
        ])
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=retriever
        )

        chain = RunnableBranch(
            (
                lambda x: len(x.get("messages", [])) == 1,
                (lambda x: x["messages"][-1].content) | compression_retriever,
            ),
            query_transform_prompt | llm | StrOutputParser() | compression_retriever,
        ).with_config(run_name="chat_retriever_chain")

        logging.info("Successfully created document retriever chain")
        return chain
    except Exception as e:
        logging.error(f"Error creating document retriever chain: {e}", exc_info=True)
        raise
def initialize_neo4j_vector(graph, chat_mode_settings):
    """Build a Neo4jVector over the existing graph index described by the mode settings.

    Uses hybrid (vector + keyword) search when a keyword index is configured;
    raises if required settings are missing or the index lookup fails.
    """
    index_name = chat_mode_settings.get("index_name")
    try:
        retrieval_query = chat_mode_settings.get("retrieval_query")
        keyword_index = chat_mode_settings.get("keyword_index", "")
        if not retrieval_query or not index_name:
            raise ValueError("Required settings 'retrieval_query' or 'index_name' are missing.")

        # Settings shared by both the hybrid and the plain vector index.
        common_kwargs = dict(
            embedding=EMBEDDING_FUNCTION,
            index_name=index_name,
            retrieval_query=retrieval_query,
            graph=graph,
            node_label=chat_mode_settings.get("node_label"),
            embedding_node_property=chat_mode_settings.get("embedding_node_property"),
            text_node_properties=chat_mode_settings.get("text_node_properties"),
        )
        if keyword_index:
            neo_db = Neo4jVector.from_existing_graph(
                search_type="hybrid",
                keyword_index_name=keyword_index,
                **common_kwargs,
            )
            logging.info(
                f"Successfully retrieved Neo4jVector Fulltext index '{index_name}' and keyword index '{keyword_index}'")
        else:
            neo_db = Neo4jVector.from_existing_graph(**common_kwargs)
            logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
    except Exception as e:
        logging.error(f"Error retrieving Neo4jVector index {index_name} : {e}")
        raise
    return neo_db
def create_retriever(neo_db, document_names, chat_mode_settings, search_k, score_threshold):
    """Create a similarity-score-threshold retriever, optionally filtered to given file names."""
    search_kwargs = {'k': search_k, 'score_threshold': score_threshold}
    if document_names and chat_mode_settings["document_filter"]:
        search_kwargs['filter'] = {'fileName': {'$in': document_names}}
        retriever = neo_db.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs=search_kwargs,
        )
        logging.info(
            f"Successfully created retriever with search_k={search_k}, score_threshold={score_threshold} for documents {document_names}")
    else:
        retriever = neo_db.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs=search_kwargs,
        )
        logging.info(f"Successfully created retriever with search_k={search_k}, score_threshold={score_threshold}")
    return retriever
def get_neo4j_retriever(graph, document_names, chat_mode_settings, score_threshold=CHAT_SEARCH_KWARG_SCORE_THRESHOLD):
    """Initialize the Neo4j vector store for the chat mode and wrap it in a retriever."""
    try:
        neo_db = initialize_neo4j_vector(graph, chat_mode_settings)
        search_k = chat_mode_settings["top_k"]
        return create_retriever(neo_db, document_names, chat_mode_settings, search_k, score_threshold)
    except Exception as e:
        index_name = chat_mode_settings.get("index_name")
        logging.error(f"Error retrieving Neo4jVector index {index_name} or creating retriever: {e}")
        raise Exception(
            f"An error occurred while retrieving the Neo4jVector index or creating the retriever. Please drop and create a new vector index '{index_name}': {e}") from e
def setup_chat(model, graph, document_names, chat_mode_settings):
    """Resolve the LLM and document retriever chain for a chat request.

    Returns (llm, doc_retriever, model_name); logs and re-raises on failure.
    """
    start_time = time.time()
    try:
        # "diffbot" is an alias resolved through the environment.
        if model == "diffbot":
            model = os.getenv('DEFAULT_DIFFBOT_CHAT_MODEL')
        llm, model_name = get_llm(model=model)
        logging.info(f"Model called in chat: {model} (version: {model_name})")

        retriever = get_neo4j_retriever(graph=graph, chat_mode_settings=chat_mode_settings,
                                        document_names=document_names)
        doc_retriever = create_document_retriever_chain(llm, retriever)

        logging.info(f"Chat setup completed in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logging.error(f"Error during chat setup: {e}", exc_info=True)
        raise
    return llm, doc_retriever, model_name
def process_chat_response(messages, history, question, model, graph, document_names, chat_mode_settings):
    """Answer ``question`` via retrieval-augmented generation and package the reply.

    Appends the AI message to ``messages`` and returns a chatbot payload;
    on any failure returns an error payload instead of raising.
    """
    try:
        llm, doc_retriever, model_version = setup_chat(model, graph, document_names, chat_mode_settings)

        docs, transformed_question = retrieve_documents(doc_retriever, messages)

        if docs:
            content, result, total_tokens, formatted_docs = process_documents(docs, question, messages, llm, model,
                                                                              chat_mode_settings)
        else:
            content = "I couldn't find any relevant documents to answer your question."
            result = {"sources": list(), "nodedetails": list(), "entities": list()}
            total_tokens = 0
            formatted_docs = ""

        messages.append(AIMessage(content=content))
        logging.info("Summarization thread started.")

        metric_details = {"question": question, "contexts": formatted_docs, "answer": content}
        return {
            "session_id": "",
            "message": content,
            "info": {
                "sources": result["sources"],
                "model": model_version,
                "nodedetails": result["nodedetails"],
                "total_tokens": total_tokens,
                "response_time": 0,
                "mode": chat_mode_settings["mode"],
                "entities": result["entities"],
                "metric_details": metric_details,
            },
            "user": "chatbot"
        }
    except Exception as e:
        logging.exception(f"Error processing chat response at {datetime.now()}: {str(e)}")
        return {
            "session_id": "",
            "message": "Something went wrong",
            "info": {
                "metrics": [],
                "sources": [],
                "nodedetails": [],
                "total_tokens": 0,
                "response_time": 0,
                "error": f"{type(e).__name__}: {str(e)}",
                "mode": chat_mode_settings["mode"],
                "entities": [],
                "metric_details": {},
            },
            "user": "chatbot"
        }
# Module-level lock shared by all summarizers. BUGFIX: the original code did
# ``with threading.Lock():`` inside the function, creating a brand-new lock
# per call — each call locked its own private lock, so there was never any
# mutual exclusion over the shared history object.
_SUMMARY_HISTORY_LOCK = threading.Lock()

def summarize_and_log(history, stored_messages, llm):
    """Summarize ``stored_messages`` with ``llm`` and replace ``history`` with the summary.

    Intended to run on a background thread. Returns True on success,
    False when there is nothing to summarize or on any error.
    """
    logging.info("Starting summarization in a separate thread.")
    if not stored_messages:
        logging.info("No messages to summarize.")
        return False
    try:
        start_time = time.time()
        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="chat_history"),
                (
                    "human",
                    "Summarize the above chat messages into a concise message, focusing on key points and relevant details that could be useful for future conversations. Exclude all introductions and extraneous information."
                ),
            ]
        )
        summarization_chain = summarization_prompt | llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})

        # Guard the clear-then-rewrite of the shared history object so two
        # concurrent summarizations cannot interleave.
        with _SUMMARY_HISTORY_LOCK:
            history.clear()
            history.add_user_message("Our current conversation summary till now")
            history.add_message(summary_message)

        logging.info(f"Chat History summarized in {time.time() - start_time:.2f} seconds")
        return True
    except Exception as e:
        logging.error(f"An error occurred while summarizing messages: {e}", exc_info=True)
        return False
def create_graph_chain(model, graph):
    """Build a GraphCypherQAChain for text-to-Cypher QA over ``graph``.

    Returns (graph_chain, qa_llm, model_name). Logs and re-raises on
    failure — the original silently returned None here, which made the
    caller fail later with an opaque tuple-unpacking TypeError.
    """
    try:
        logging.info(f"Graph QA Chain using LLM model: {model}")
        cypher_llm, model_name = get_llm(model)
        qa_llm, model_name = get_llm(model)

        graph_chain = GraphCypherQAChain.from_llm(
            cypher_llm=cypher_llm,
            qa_llm=qa_llm,
            validate_cypher=True,
            graph=graph,
            # Required by LangChain: this chain executes LLM-generated Cypher
            # against the database.
            allow_dangerous_requests=True,
            return_intermediate_steps=True,
            top_k=3
        )
        logging.info("GraphCypherQAChain instance created successfully.")
        return graph_chain, qa_llm, model_name
    except Exception as e:
        logging.error(f"An error occurred while creating the GraphCypherQAChain instance. : {e}")
        raise
def get_graph_response(graph_chain, question):
    """Run the Cypher QA chain and normalize its output.

    Returns {'response', 'cypher_query', 'context'}; the generated Cypher
    is flattened to a single line. Logs and re-raises on failure — the
    original implicitly returned None, which broke callers using ``.get()``.
    """
    try:
        cypher_res = graph_chain.invoke({"query": question})

        cypher_query = ""
        context = []
        for step in cypher_res.get("intermediate_steps", []):
            if "query" in step:
                # Strip the leading "cypher\n" fence and collapse newlines.
                cypher_query = step["query"].replace("cypher\n", "").replace("\n", " ").strip()
            elif "context" in step:
                context = step["context"]

        return {
            "response": cypher_res.get("result"),
            "cypher_query": cypher_query,
            "context": context
        }
    except Exception as e:
        logging.error(f"An error occurred while getting the graph response : {e}")
        raise
def process_graph_response(model, graph, question, messages, history):
    """Answer ``question`` in graph (text-to-Cypher) mode and package the reply.

    Appends the AI message to ``messages``; returns an error payload
    (never raises) on failure.
    """
    # BUGFIX: defined before the try so the except-path payload cannot hit a
    # NameError when create_graph_chain fails before assigning it.
    model_version = ""
    try:
        graph_chain, qa_llm, model_version = create_graph_chain(model, graph)
        graph_response = get_graph_response(graph_chain, question)

        ai_response_content = graph_response.get("response", "Something went wrong")
        ai_response = AIMessage(content=ai_response_content)
        messages.append(ai_response)
        logging.info("Summarization thread started.")

        metric_details = {"question": question, "contexts": graph_response.get("context", ""),
                          "answer": ai_response_content}
        return {
            "session_id": "",
            "message": ai_response_content,
            "info": {
                "model": model_version,
                "cypher_query": graph_response.get("cypher_query", ""),
                "context": graph_response.get("context", ""),
                "mode": "graph",
                "response_time": 0,
                "metric_details": metric_details,
            },
            "user": "chatbot"
        }
    except Exception as e:
        logging.exception(f"Error processing graph response at {datetime.now()}: {str(e)}")
        return {
            "session_id": "",
            "message": "Something went wrong",
            "info": {
                "model": model_version,
                "cypher_query": "",
                "context": "",
                "mode": "graph",
                "response_time": 0,
                "error": f"{type(e).__name__}: {str(e)}"
            },
            "user": "chatbot"
        }
def create_neo4j_chat_message_history(graph, session_id, write_access=True):
    """Return a chat history backed by Neo4j (write access) or the in-process store."""
    try:
        if not write_access:
            return get_history_by_session_id(session_id)
        return Neo4jChatMessageHistory(
            graph=graph,
            session_id=session_id
        )
    except Exception as e:
        logging.error(f"Error creating Neo4jChatMessageHistory: {e}")
        raise
def get_chat_mode_settings(mode, settings_map=CHAT_MODE_CONFIG_MAP):
    """Look up the configuration for ``mode``, falling back to the default mode.

    Returns a copy of the config entry with 'mode' added. BUGFIX: the
    original assigned ``chat_mode_settings["mode"] = mode`` directly on the
    shared ``settings_map`` entry, mutating the global configuration so one
    request's mode leaked into every later request using the same entry.
    """
    default_settings = settings_map[CHAT_DEFAULT_MODE]
    try:
        chat_mode_settings = dict(settings_map.get(mode, default_settings))
        chat_mode_settings["mode"] = mode
        logging.info(f"Chat mode settings: {chat_mode_settings}")
    except Exception as e:
        logging.error(f"Unexpected error: {e}", exc_info=True)
        raise
    return chat_mode_settings
def QA_RAG(question, session_id, mode, write_access=True):
    """Entry point: answer ``question`` for ``session_id`` in the requested chat mode.

    Graph mode goes through the Cypher QA chain; every other mode goes
    through vector RAG restricted to the AM400 user-guide documents.
    Returns only the answer text.
    """
    logging.info(f"Chat Mode: {mode}")
    graph = setup_neo4j_connection()
    model = MODEL

    # Hard-coded corpus: the AM400 user guide, pre-split by page ranges.
    # Kept as a JSON string because the vector path json.loads() it below.
    document_names = json.dumps([
        "AM400_User_guide_1_12.md", "AM400_User_guide_13_24.md", "AM400_User_guide_25_36.md",
        "AM400_User_guide_37_56.md", "AM400_User_guide_57_63.md", "AM400_User_guide_64_90.md",
        "AM400_User_guide_90_104.md", "AM400_User_guide_105_114.md", "AM400_User_guide_115_140.md",
        "AM400_User_guide_141_153.md", "AM400_User_guide_154_163.md", "AM400_User_guide_164_176.md",
        "AM400_User_guide_177_178.md", "AM400_User_guide_179_211.md", "AM400_User_guide_212_228.md",
        "AM400_User_guide_229_240.md",
    ])

    history = create_neo4j_chat_message_history(graph, session_id, write_access)
    messages = history.messages
    messages.append(HumanMessage(content=question))

    if mode == CHAT_GRAPH_MODE:
        result = process_graph_response(model, graph, question, messages, history)
    else:
        chat_mode_settings = get_chat_mode_settings(mode=mode)
        document_names = list(map(str.strip, json.loads(document_names)))
        result = process_chat_response(messages, history, question, model, graph, document_names,
                                       chat_mode_settings)
    result["session_id"] = session_id
    return result['message']
print("Starting script...")
def validate_video(video_path):
    """
    Validate video file and check if frames can be extracted.
    Returns tuple of (is_valid, message, num_frames)
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return False, "Failed to open video file", 0

        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        print(f"Video properties: {width}x{height} @ {fps}fps, {total_frames} frames")

        if total_frames < 1:
            return False, "Video has no frames", 0
        if fps <= 0:
            return False, "Invalid video FPS", 0

        # Try to read first frame
        ret, frame = cap.read()
        if not ret or frame is None:
            return False, "Failed to read first frame", 0

        # Try to read last frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - 1)
        ret, frame = cap.read()
        if not ret or frame is None:
            return False, "Failed to read last frame", 0

        return True, "Video validated successfully", total_frames
    except Exception as e:
        return False, f"Error validating video: {str(e)}", 0
    finally:
        # BUGFIX: the original released the capture only on the success path,
        # leaking the OS handle on every early return and on exceptions.
        if cap is not None:
            cap.release()
def make_single_video_dataloader(video_path, config):
    """Create a dataloader for a single video using CSV file"""
    print(f"Creating data loader for video: {video_path}")
    try:
        # Basic sanity check: the video must open and contain frames.
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Failed to open video file")
            return None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.release()
        if total_frames < 1:
            print("Video has no frames")
            return None

        # init_data expects a CSV listing "<path> <label>" per line.
        csv_path = 'video_list.csv'
        with open(csv_path, 'w') as f:
            f.write(f"{video_path} 0\n")
        print(f"Created CSV at: {csv_path}")

        data = pd.read_csv(csv_path, header=None, delimiter=" ")
        print('samples', list(data.values[:, 0]))
        print('labels', list(data.values[:, 1]))

        transform = make_transforms(
            training=False,
            num_views_per_clip=config['num_views_per_segment'],
            random_horizontal_flip=False,
            random_resize_aspect_ratio=(0.75, 4 / 3),
            random_resize_scale=(0.08, 1.0),
            reprob=0.25,
            auto_augment=True,
            motion_shift=False,
            crop_size=config['resolution'],
        )

        print(f"Creating data loader with CSV: {csv_path}")
        try:
            data_loader, info = init_data(
                data='VideoDataset',
                root_path=[csv_path],
                transform=transform,
                batch_size=1,
                world_size=1,
                rank=0,
                clip_len=config['frames_per_clip'],
                frame_sample_rate=config['frame_step'],
                num_clips=config['num_segments'],
                allow_clip_overlap=True,
                num_workers=1,
                copy_data=False,
                drop_last=False
            )
            print(f"Data loader info: {info}")
            return data_loader
        except Exception as e:
            print("\n=== Data Loader Creation Error ===")
            print(f"Error type: {type(e).__name__}")
            print(f"Error message: {str(e)}")
            print("\nFull traceback:")
            traceback.print_exc(file=sys.stdout)
            return None
    except Exception as e:
        print(f"Error creating data loader: {str(e)}")
        return None
class VideoClassifier:
    """Frozen video encoder plus attentive classifier head for single-video inference.

    Both models are loaded at construction time from paths in ``config``;
    ``predict`` runs the full pipeline (dataloader -> encoder -> classifier).
    """

    def __init__(self, config):
        # config: dict of model/eval settings (resolution, checkpoint paths,
        # clip layout, class names, ...) consumed throughout this class.
        print("Initializing VideoClassifier...")
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        self.config = config
        self.encoder, self.classifier = self.load_models()
        # NOTE(review): this transform is not referenced by predict(), which
        # relies on the dataloader's own transform — confirm it is needed.
        self.transform = make_transforms(
            training=False,
            num_views_per_clip=config['num_views_per_segment'],
            crop_size=config['resolution']
        )

    def load_checkpoint(self, device, r_path, classifier, opt, scaler):
        """
        Load checkpoint for classifier, optimizer and scaler.
        Args:
            device: torch device to load the model onto (currently unused;
                weights are loaded onto CPU and moved by the caller)
            r_path: path to checkpoint file
            classifier: classifier model to load weights into
            opt: optimizer to load state into
            scaler: gradient scaler for mixed precision training (can be None)
        Returns:
            tuple: (classifier, optimizer, scaler, epoch) — on any failure the
            inputs are returned unchanged with epoch 0.
        """
        try:
            print(f'Loading checkpoint from: {r_path}')
            # Load onto CPU first to avoid GPU OOM on large checkpoints.
            checkpoint = torch.load(r_path, map_location=torch.device('cpu'))
            # Get epoch from checkpoint
            epoch = checkpoint.get('epoch', 0)
            print(f'Checkpoint from epoch: {epoch}')
            # Clean and load classifier state dict
            if 'classifier' in checkpoint:
                state_dict = checkpoint['classifier']
                # Remove 'module.' prefix if present (left over from
                # DistributedDataParallel training).
                new_state_dict = {}
                for k, v in state_dict.items():
                    if k.startswith('module.'):
                        new_state_dict[k[7:]] = v
                    else:
                        new_state_dict[k] = v
                try:
                    msg = classifier.load_state_dict(new_state_dict)
                    print(f'Loaded classifier state with message: {msg}')
                except Exception as e:
                    print(f'Error loading classifier state dict: {str(e)}')
                    raise
            else:
                print('No classifier state found in checkpoint')
                raise KeyError('No classifier state in checkpoint')
            # Load optimizer state if present
            if 'opt' in checkpoint and opt is not None:
                try:
                    opt.load_state_dict(checkpoint['opt'])
                    print('Loaded optimizer state')
                except Exception as e:
                    print(f'Error loading optimizer state: {str(e)}')
                    raise
            else:
                print('No optimizer state found in checkpoint')
            # Load scaler state if present and scaler is provided
            if scaler is not None and 'scaler' in checkpoint:
                try:
                    scaler.load_state_dict(checkpoint['scaler'])
                    print('Loaded scaler state')
                except Exception as e:
                    print(f'Error loading scaler state: {str(e)}')
                    raise
            elif scaler is not None:
                print('No scaler state found in checkpoint')
            # Clean up memory
            del checkpoint
            torch.cuda.empty_cache()
            return classifier, opt, scaler, epoch
        except FileNotFoundError:
            print(f'Checkpoint file not found: {r_path}')
            return classifier, opt, scaler, 0
        except Exception as e:
            print(f'Error loading checkpoint: {str(e)}')
            print(f'Traceback: {traceback.format_exc()}')
            return classifier, opt, scaler, 0

    def load_models(self):
        """Build the encoder and classifier and load their checkpoint weights.

        Returns (encoder, classifier), both in eval mode on self.device.
        """
        print("Loading models...")
        # Initialize encoder
        encoder = init_model(
            crop_size=self.config['resolution'],
            device=self.device,
            pretrained=self.config['pretrained_path'],
            model_name=self.config['model_name'],
            patch_size=self.config['patch_size'],
            tubelet_size=self.config['tubelet_size'],
            frames_per_clip=self.config['frames_per_clip'],
            uniform_power=self.config['uniform_power'],
            checkpoint_key=self.config['checkpoint_key'],
            use_SiLU=self.config['use_silu'],
            tight_SiLU=self.config['tight_silu'],
            use_sdpa=self.config['use_sdpa'])
        if self.config['frames_per_clip'] == 1:
            # Process each frame independently and aggregate
            encoder = FrameAggregation(encoder).to(self.device)
        else:
            # Process each video clip independently and aggregate
            encoder = ClipAggregation(
                encoder,
                tubelet_size=self.config['tubelet_size'],
                attend_across_segments=self.config['attend_across_segments']
            ).to(self.device)
        print("Loading pretrained weights...")
        # Load pretrained weights
        # NOTE(review): init_model above already receives `pretrained`; this
        # second explicit load may be redundant — confirm before removing.
        checkpoint = torch.load(self.config['pretrained_path'], map_location=self.device)
        encoder.load_state_dict(checkpoint[self.config['checkpoint_key']], strict=False)
        encoder.eval()
        print("Initializing classifier...")
        # Initialize and load classifier
        classifier = AttentiveClassifier(
            embed_dim=encoder.embed_dim,
            num_heads=encoder.num_heads,
            depth=1,
            num_classes=self.config['num_classes']
        ).to(self.device)
        # Optimizer/scaler are only needed so load_checkpoint can restore
        # their state; they are not used for inference.
        optimizer, scaler, scheduler, wd_scheduler = init_opt(
            classifier=classifier,
            wd=0.01,
            start_lr=0.001,
            ref_lr=0.001,
            final_lr=0.0,
            iterations_per_epoch=1,
            warmup=0.,
            num_epochs=20,
            use_bfloat16=True)
        # classifier = DistributedDataParallel(classifier, static_graph=True)
        classifier, optimizer, scaler, epoch = self.load_checkpoint(
            device=self.device,
            r_path=self.config['classifier_path'],
            classifier=classifier,
            opt=optimizer,
            scaler=scaler
        )
        # classifier_checkpoint = torch.load(self.config['classifier_path'], map_location=self.device)
        # classifier.load_state_dict(classifier_checkpoint['classifier'])
        classifier.eval()
        return encoder, classifier

    def predict(self, video_path):
        """
        Predict the class of a video.

        Returns (class_name, confidence) on success, or an
        ("Error...", 0.0) tuple on any failure.
        """
        print(f"Starting prediction for video: {video_path}")
        try:
            # Check if video file exists
            if not os.path.exists(video_path):
                print(f"Error: Video file not found: {video_path}")
                return "Error: Video file not found", 0.0
            # Check video file size
            file_size = os.path.getsize(video_path)
            if file_size == 0:
                print("Error: Video file is empty")
                return "Error: Empty video file", 0.0
            print(f"Video file size: {file_size} bytes")
            # Initialize data loader
            data_loader = make_single_video_dataloader(video_path, self.config)
            if data_loader is None:
                return "Error: Failed to create data loader", 0.0
            # Process the video
            with torch.no_grad():
                try:
                    # Get first batch
                    batch = next(iter(data_loader))
                    # Prepare clips
                    # NOTE(review): assumes batch = (clips, labels, clip_indices)
                    # as produced by the VideoDataset loader — confirm.
                    clips = [
                        [dij.to(self.device, non_blocking=True) for dij in di]
                        for di in batch[0]
                    ]
                    clip_indices = [d.to(self.device, non_blocking=True) for d in batch[2]]
                    print("Running encoder forward pass...")
                    encoder_outputs = self.encoder(clips, clip_indices)
                    print("Running classifier forward pass...")
                    if self.config['attend_across_segments']:
                        # One logit tensor per spatial view; average softmax
                        # probabilities over views.
                        outputs = [self.classifier(o) for o in encoder_outputs]
                        probs = sum([F.softmax(o, dim=1) for o in outputs]) / len(outputs)
                    else:
                        # NOTE(review): the comprehension variable `os` shadows
                        # the `os` module inside these comprehensions (harmless
                        # under Python 3 scoping, but worth renaming).
                        outputs = [[self.classifier(ost) for ost in os] for os in encoder_outputs]
                        probs = sum([sum([F.softmax(ost, dim=1) for ost in os])
                                     for os in outputs]) / len(outputs) / len(outputs[0])
                    # Get prediction and confidence
                    pred_class = probs.argmax(dim=1).item()
                    confidence = probs.max(dim=1).values.item()
                    class_name = self.config['class_names'][pred_class]
                    print(f'Predicted class: {class_name} (index: {pred_class})')
                    print(f'Confidence: {confidence:.4f}')
                    return class_name, confidence
                except StopIteration:
                    print("Error: No data in data loader")
                    return "Error: Failed to load video frames", 0.0
                except Exception as e:
                    print(f"Error during prediction: {str(e)}")
                    return f"Error during prediction: {str(e)}", 0.0
        except Exception as e:
            print(f"Error in prediction pipeline: {str(e)}")
            return f"Error in prediction pipeline: {str(e)}", 0.0
def process_video(classifier, video_path):
    """Normalize the uploaded video (converting .webm to .mp4 if needed) and classify it.

    Returns (class_name_or_error_message, confidence).
    """
    print("Processing video...")
    if video_path is None:
        return "No video uploaded", 0.0

    mp4_path = None  # path of a temporary converted file, if one is created
    try:
        if not isinstance(video_path, str):
            print(f"Error: Unexpected video input type: {type(video_path)}")
            return "Error: Invalid video format", 0.0

        print(f"Processing video from path: {video_path}")
        # Gradio may hand us a .webm recording; transcode it to .mp4 first.
        if video_path.endswith('.webm'):
            print("Converting .webm to .mp4 using ffmpeg...")
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
                mp4_path = temp_file.name
            command = ['ffmpeg', '-y', '-i', video_path, '-c:v', 'libx264', '-preset', 'fast', '-crf', '23', '-c:a',
                       'aac', mp4_path]
            result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            if result.returncode != 0:
                print(f"FFmpeg error: {result.stderr.decode()}")
                return "Error during video conversion", 0.0
            print(f"Converted video saved to: {mp4_path}")
            video_path = mp4_path

        try:
            return classifier.predict(video_path)
        finally:
            # Remove the temporary converted file, if any.
            if mp4_path and os.path.exists(mp4_path):
                os.unlink(mp4_path)
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return f"Error: {str(e)}", 0.0
def process_highlights_and_images(qa_reuslt_pure):
    """Highlight answer phrases that match figure captions and fetch related images.

    Each line of the answer is embedded and matched against ``Figure`` nodes in
    Neo4j (cosine similarity > 0.76). Answer phrases whose similarity to any
    retrieved caption reaches 0.75 are wrapped in a yellow ``<mark>`` tag.

    Args:
        qa_reuslt_pure: Raw answer text from the QA chain (name kept for
            interface compatibility).

    Returns:
        Tuple of ``(highlighted_html, image_refs, image_info)``; empty values
        when the input is falsy.
    """
    if not qa_reuslt_pure:
        return "", [], {}

    response = qa_reuslt_pure
    print('qa_reuslt_pure', qa_reuslt_pure)
    # Strip markdown bold markers before embedding each line.
    cleaned_lines = [raw.replace("**", "") for raw in qa_reuslt_pure.strip().split('\n')]

    graph = setup_neo4j_connection()
    try:
        captions, images, similarities = [], [], []
        # Query is loop-invariant; hoisted once outside the loop.
        query = """
        WITH $line_embedding AS line_embedding
        MATCH (f:Figure)
        WHERE vector.similarity.cosine(f.embedding, line_embedding) > $similarity_threshold
        RETURN f, vector.similarity.cosine(f.embedding, line_embedding) AS similarity
        ORDER BY similarity DESC
        LIMIT 1
        """
        for cleaned in cleaned_lines:
            line_embedding = EMBEDDING_FUNCTION.embed_query(cleaned)
            records = graph.query(query, {
                "line_embedding": line_embedding,
                "similarity_threshold": 0.76
            })
            for record in records:
                node = record["f"]
                score = record["similarity"]
                if 'img_ref' in node:
                    images.append(node['img_ref'])
                    captions.append(node['description'])
                    similarities.append(score)

        highlighted_response = response
        if captions:
            # Compare every response phrase against every retrieved caption and
            # mark the phrases that align closely with at least one caption.
            phrases = response.split('. ')
            phrase_vecs = [EMBEDDING_FUNCTION.embed_query(p) for p in phrases]
            caption_vecs = [EMBEDDING_FUNCTION.embed_query(c) for c in captions]
            sim_matrix = cosine_similarity(phrase_vecs, caption_vecs)

            pieces = []
            last_index = len(phrases) - 1
            for idx, phrase in enumerate(phrases):
                if np.max(sim_matrix[idx]) >= 0.75:
                    pieces.append(f"<mark style='background-color: yellow;'>{phrase}</mark>")
                else:
                    pieces.append(phrase)
                if idx != last_index:
                    pieces.append(". ")
            highlighted_response = "".join(pieces)

        image_info = [
            {"caption": caption, "similarity": f"{sim:.3f}"}
            for caption, sim in zip(captions, similarities)
        ]
        return highlighted_response, images, image_info
    finally:
        if hasattr(graph, 'close'):
            graph.close()
class AudioAgent:
    """Voice-driven orchestrator for the video-analysis / Q&A pipeline.

    Uses the OpenAI client for three things: Whisper transcription of the
    user's spoken request, GPT-4 for intent routing and question extraction,
    and TTS for spoken responses. Progress through the pipeline
    (video analyzed -> question captured -> question answered -> highlights
    generated) is tracked in ``self.conversation_state`` so each stage can
    gate the next.
    """
    def __init__(self, api_key: str):
        self.client = OpenAI(api_key=api_key)
        self.temp_dir = tempfile.mkdtemp()  # Scratch dir for generated TTS files
        # NOTE(review): the temp dir is never removed; for a long-lived app
        # consider shutil.rmtree in a teardown hook — confirm lifecycle.
        print('temp dir', self.temp_dir)
        # Pipeline state shared across process_request calls; values are
        # filled in as each stage completes.
        self.conversation_state = {
            'current_stage': None,
            'video_prediction': None,
            'video_confidence': None,
            'generated_question': None,
            'pure_response': None,
            'highlighted_response': None,
            'gallery': None,
            'image_infos': None,
            'webcam_active': False
        }

    def _get_temp_path(self, filename: str) -> str:
        """Return a path for *filename* inside this agent's temp directory."""
        temp_path = os.path.join(self.temp_dir, filename)
        print('temp_path', temp_path)
        return temp_path

    def transcribe_audio(self, audio_path: str) -> str:
        """Transcribe an audio file with Whisper; return "" on any failure."""
        try:
            if not os.path.isfile(audio_path):
                print(f"Invalid audio path: {audio_path}")
                return ""
            with open(audio_path, "rb") as audio_file:
                transcript = self.client.audio.transcriptions.create(
                    model="whisper-1",
                    file=audio_file
                )
            print('transcript text', transcript.text)
            return transcript.text
        except Exception as e:
            print(f"Error in transcription: {e}")
            return ""

    def text_to_speech(self, text: str, filename: str) -> str:
        """Synthesize *text* to an mp3 in the temp dir; return its path, or "" on failure."""
        try:
            output_path = self._get_temp_path(filename)
            print('output_path', output_path)
            response = self.client.audio.speech.create(
                model="tts-1",
                voice="alloy",
                input=text
            )
            # NOTE(review): stream_to_file is deprecated in newer OpenAI SDKs
            # in favor of with_streaming_response — verify the pinned SDK version.
            response.stream_to_file(output_path)
            return output_path
        except Exception as e:
            print(f"Error in TTS conversion: {e}")
            return ""

    def determine_action(self, instruction: str) -> str:
        """Classify the user's intent into one of the pipeline actions.

        Returns one of: analyze_video, ask_question, answer_questions,
        generate_highlights, unclear (also the fallback on API errors).
        """
        # Current state is injected into the system prompt so the LLM can
        # disambiguate stage-dependent requests.
        messages = [
            {
                "role": "system",
                "content": f"""Determine what action the user wants to take based on the current conversation state:
                Current state: {self.conversation_state['current_stage']}
                Video prediction: {self.conversation_state['video_prediction']}
                Possible actions:
                - analyze_video: If they want to analyze a video
                - ask_question: If they are asking their own question
                - answer_questions: If they want to answer questions
                - generate_highlights: If they want to process answers for highlights
                - unclear: If the request is unclear
                Return only one of these action names."""
            },
            {"role": "user", "content": str(instruction)}
        ]
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=messages
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error determining action: {e}")
            return "unclear"

    def extract_question(self, instruction: str) -> str:
        """Pull the bare question out of a spoken request; fall back to the raw input."""
        messages = [
            {
                "role": "system",
                "content": "Extract the actual question from the user's input. Return only the question itself."
            },
            {"role": "user", "content": str(instruction)}
        ]
        try:
            response = self.client.chat.completions.create(
                model="gpt-4-turbo-preview",
                messages=messages
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error extracting question: {e}")
            return instruction

    def process_request(self, instruction: str, **kwargs) -> Dict[str, Any]:
        """Route one user instruction through the staged pipeline.

        Kwargs used per action: ``webcam_input`` and ``classifier`` for
        analyze_video; ``session_id`` and ``mode`` for answer_questions.

        Returns:
            On success, a dict with ``results``, ``audio_path`` (TTS reply),
            ``action`` and the flattened result keys; on a stage-ordering or
            processing failure, a dict with only an ``error`` key.
        """
        try:
            action = self.determine_action(instruction)
            results = {}
            final_audio_path = None
            # if action == "take_video":
            #     self.conversation_state['webcam_active'] = True
            #     return {
            #         "results": {"message": "Activating webcam..."},
            #         "action": "take_video",
            #         "webcam_active": True
            #     }
            if action == "analyze_video":
                video_input = kwargs.get('webcam_input')
                print(f"Video input received: {video_input}, type: {type(video_input)}")
                video_path = None
                # Normalize the three shapes Gradio can hand us:
                # tuple (webcam recording), str (file path), dict (upload).
                if isinstance(video_input, tuple):
                    # Webcam recording format
                    video_path = video_input[0] if video_input else None
                elif isinstance(video_input, str):
                    # Direct file path
                    video_path = video_input
                elif isinstance(video_input, dict) and 'name' in video_input:
                    # Uploaded file format
                    video_path = video_input['name']
                print(f"Processed video path: {video_path}")
                if video_path and os.path.isfile(video_path):
                    classifier = kwargs.get('classifier')
                    if classifier:
                        video_result = process_video(classifier, video_path)
                        self.conversation_state['video_prediction'] = video_result[0]
                        self.conversation_state['video_confidence'] = video_result[1]
                        self.conversation_state['current_stage'] = 'video_analyzed'
                        results.update({
                            'video_prediction': video_result[0],
                            'video_confidence': video_result[1]
                        })
                        final_audio_path = self.text_to_speech(
                            f"Video analysis complete. Detected action: {video_result[0]} with confidence {video_result[1]:.2f}. Would you like to ask questions about this activity?",
                            "response_video.mp3"
                        )
            elif action == "ask_question":
                # Requires a prior video analysis so the question has context.
                if self.conversation_state['video_prediction']:
                    # Extract the actual question from user's voice input
                    custom_question = self.extract_question(instruction)
                    self.conversation_state['generated_question'] = custom_question
                    self.conversation_state['current_stage'] = 'questions_generated'
                    results['generated_question'] = custom_question
                    final_audio_path = self.text_to_speech(
                        f"I understand your question: '{custom_question}'. Would you like me to answer it?",
                        "response_custom_question.mp3"
                    )
                else:
                    return {"error": "Please analyze a video first"}
            elif action == "answer_questions":
                if self.conversation_state['generated_question']:
                    # QA_RAG is the knowledge-graph RAG entry point defined
                    # elsewhere in this file.
                    pure_response = QA_RAG(
                        self.conversation_state['generated_question'],
                        kwargs.get('session_id', 1),
                        kwargs.get('mode', 'CHAT_VECTOR_GRAPH_MODE')
                    )
                    self.conversation_state['pure_response'] = pure_response
                    self.conversation_state['current_stage'] = 'questions_answered'
                    results.update({
                        'pure_response': pure_response
                    })
                    final_audio_path = self.text_to_speech(
                        f"Sure! I've answered the questions. {pure_response}. Would you like me to generate highlights and related images?",
                        "response_qa.mp3"
                    )
                else:
                    return {"error": "Please generate questions first"}
            elif action == "generate_highlights":
                if self.conversation_state['pure_response']:
                    highlighted_response, gallery, image_infos = process_highlights_and_images(
                        self.conversation_state['pure_response']
                    )
                    self.conversation_state.update({
                        'highlighted_response': highlighted_response,
                        'gallery': gallery,
                        'image_infos': image_infos,
                        'current_stage': 'highlights_generated'
                    })
                    results.update({
                        'highlighted_response': highlighted_response,
                        'gallery': gallery,
                        'image_infos': image_infos
                    })
                    final_audio_path = self.text_to_speech(
                        "I've generated highlights and retrieved related images. Is there anything else you'd like to know?",
                        "response_highlights.mp3"
                    )
                else:
                    return {"error": "Please answer questions first"}
            else:
                # Unclear intent: suggest the next sensible step for the
                # current pipeline stage.
                suggestion = ""
                if not self.conversation_state['current_stage']:
                    suggestion = "You can start by asking me to analyze a video."
                elif self.conversation_state['current_stage'] == 'video_analyzed':
                    suggestion = "You can ask me to generate questions about the activity."
                elif self.conversation_state['current_stage'] == 'questions_generated':
                    suggestion = "You can ask me to answer the questions."
                elif self.conversation_state['current_stage'] == 'questions_answered':
                    suggestion = "You can ask me to generate highlights and images."
                return {
                    "error": f"I'm not sure what you want me to do. {suggestion}"
                }
            return {
                "results": results,
                "audio_path": final_audio_path,
                "action": action,
                **results
            }
        except Exception as e:
            print(f"Error in process_request: {e}")
            return {"error": f"Error processing request: {e}"}
def process_audio_input(
    agent,
    audio_path: str,
    webcam_input: str = None,
    classifier=None,
    mode: str = 'CHAT_VECTOR_GRAPH_MODE',
    session_id: int = 1
) -> Tuple[Dict[str, Any], str]:
    """Transcribe a spoken request and route it through the agent pipeline.

    Args:
        agent: An ``AudioAgent`` (or compatible) instance.
        audio_path: Path to the recorded audio file.
        webcam_input: Optional video input forwarded to video analysis.
        classifier: Optional video classifier forwarded to video analysis.
        mode: Knowledge-guidance retrieval mode name.
        session_id: Chat session identifier.

    Returns:
        Tuple of (result dict, path to the spoken-response audio or "").
    """
    spoken_text = agent.transcribe_audio(audio_path)
    if not spoken_text:
        return {"error": "Error in transcription"}, ""

    outcome = agent.process_request(
        spoken_text,
        webcam_input=webcam_input,
        classifier=classifier,
        mode=mode,
        session_id=session_id
    )
    # A failed stage surfaces as an "error" key; no audio reply in that case.
    if "error" in outcome:
        return outcome, ""
    return outcome, outcome.get("audio_path", "")
def main():
    """Build the full application: download model weights, construct the
    video classifier, and assemble the Gradio UI.

    Returns:
        The constructed ``gr.Blocks`` demo (caller launches it).
    """
    print("Starting main function...")
    # from huggingface_hub import snapshot_download
    # snapshot_download(repo_id="IELTS8/pretrained_model")
    import os
    import urllib.request
    # Model checkpoints to fetch on first run: local path -> download URL.
    models = {
        "pretrained_model/vith16.pth.tar": "https://dl.fbaipublicfiles.com/jepa/vith16/vith16.pth.tar",
        "pretrained_model/v-jepa-224-classifier/jepa-latest.pth.tar": "https://huggingface.co/IELTS8/pretrained_model/resolve/main/v-jepa-224-classifier/jepa-latest.pth.tar"
    }
    # Function to ensure directory exists
    def ensure_directory_exists(file_path):
        directory = os.path.dirname(file_path)
        if not os.path.exists(directory):
            os.makedirs(directory)
    # Function to download file if it does not exist (idempotent across restarts)
    def download_file(file_path, url):
        if not os.path.exists(file_path):
            print(f"Downloading {file_path} from {url}...")
            ensure_directory_exists(file_path)
            urllib.request.urlretrieve(url, file_path)
            print(f"Downloaded: {file_path}")
        else:
            print(f"File already exists: {file_path}")
    # Identity pass-through for Gradio video updates.
    def update_video(video_path):
        return video_path
    # Check and download models
    for file_path, url in models.items():
        download_file(file_path, url)
    # from huggingface_hub import hf_hub_download
    # hf_hub_download(repo_id="IELTS8/pretrained_model")
    # V-JEPA ViT-H/16 configuration: 11 manufacturing-operation classes,
    # multi-segment multi-view evaluation with cross-segment attention.
    model_config = {
        'model_name': 'vit_huge',
        'resolution': 224,
        'patch_size': 16,
        'frames_per_clip': 16,
        'tubelet_size': 2,
        'uniform_power': True,
        'use_sdpa': True,
        'use_silu': False,
        'tight_silu': False,
        'num_classes': 11,
        'pretrained_path': 'pretrained_model/vith16.pth.tar',
        'classifier_path': 'pretrained_model/v-jepa-224-classifier/jepa-latest.pth.tar',
        'checkpoint_key': 'target_encoder',
        'num_segments': 2,
        'num_views_per_segment': 3,
        'frame_step': 3,
        'attend_across_segments': True,
        'class_names': [
            'clean the build plate', 'clean the recoater', 'insert the build plate screw',
            'measure the build plate thickness', 'open the process chamber door',
            'operate valves of the large safe change filter',
            'place build plate inside the process chamber',
            'place recoater into the process chamber', 'press reset button on the control panel',
            'press stop button on the control panel',
            'wipe laser window inside the process chamber'
        ]
    }
    # Required by the distributed utilities even for single-process serving.
    world_size, rank = init_distributed()
    print(f'Initialized (rank/world-size) {rank}/{world_size}')
    print("Creating classifier...")
    classifier = VideoClassifier(model_config)
    # print("Setting up Gradio interface...")
    # with gr.Blocks(theme=gr.themes.Soft()) as demo:
    #     gr.Markdown("# 🏭 Sim2Real: Action Recognition (Synthetic Data) & Knowledge Guidance (Knowledge Graph)")
    # Create a custom theme with semi-transparent panels.
    custom_theme = gr.themes.Soft().set(
        body_background_fill="rgba(255, 255, 255, 0.3)",  # Main background with 70% opacity
        block_background_fill="rgba(255, 255, 255, 0.3)",  # Component blocks with 50% opacity
        panel_background_fill="rgba(255, 255, 255, 0.3)",  # Panels with 30% opacity
        # Additional theme customizations for better visibility
        block_border_width="1px",
        block_border_color="rgba(128, 128, 128, 0.3)",
        block_radius="8px",
        # Text colors for better contrast against transparent background
        block_title_text_color="rgba(0, 0, 0, 0.8)",
        block_label_text_color="rgba(0, 0, 0, 0.7)",
        body_text_color="rgba(0, 0, 0, 0.8)"
    )
    print("Setting up Gradio interface...")
    sample_videos = ["example_1.mp4", "example_2.mp4", "example_3.mp4", "example_4.mp4", "example_5.mp4"]
    with gr.Blocks(theme=custom_theme, css="""
        .gradio-container {
            background: rgba(255, 255, 255, 0.7) !important;
        }
        .block {
            background: rgba(255, 255, 255, 0.5) !important;
            backdrop-filter: blur(10px);
        }
        .panel {
            background: rgba(255, 255, 255, 0.3) !important;
        }
    """) as demo:
        gr.Markdown("# 🏭 Sim2Real: Action Recognition (Synthetic Data) & Knowledge Guidance (Knowledge Graph)")
        # Add state for webcam control / request debouncing
        last_processed_time = gr.State(time.time())
        is_processing = gr.State(False)
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### 🎬 Operation Recognition")
                with gr.Row():
                    with gr.Column(scale=3):
                        video_input = gr.Video(
                            label="Upload Video",
                            sources=["upload", "webcam"],
                            height=400,
                            width=600
                        )
                    with gr.Column(scale=1):
                        # Force examples to render in the right column
                        with gr.Group():
                            gr.Examples(
                                examples=sample_videos,
                                inputs=video_input,
                                examples_per_page=5,  # Show all examples at once
                                label="Sample Videos"
                            )
                with gr.Row():
                    audio_input = gr.Audio(
                        sources=["microphone"],
                        type="filepath",
                        label="Speak your request",
                        streaming=False
                    )
                with gr.Row():
                    text_output = gr.Textbox(label="Response")
                    audio_output = gr.Audio(
                        label="Audio Response",
                        autoplay=True
                    )
                status = gr.Textbox(label="Status", value="Ready")
                # Hidden components for storing intermediate results
                video_prediction = gr.Textbox(visible=False)
                video_confidence = gr.Number(visible=False)
                generated_question = gr.Textbox(visible=False)
            with gr.Column(scale=3):
                gr.Markdown("### 🔍 Knowledge Guidance")
                with gr.Row():
                    # Retrieval mode and session are fixed defaults, hidden from users.
                    mode = gr.Dropdown(
                        choices=[
                            'CHAT_VECTOR_GRAPH_MODE',
                            'CHAT_GRAPH_MODE',
                            'CHAT_VECTOR_MODE',
                            'CHAT_FULLTEXT_MODE',
                            'CHAT_VECTOR_GRAPH_FULLTEXT_MODE'
                        ],
                        label="Select Mode",
                        value='CHAT_VECTOR_GRAPH_MODE',
                        visible=False,
                    )
                    session_id = gr.Number(
                        label="Session ID",
                        value=1,
                        visible=False
                    )
                pure_response = gr.Markdown(visible=False)
                highlighted_response = gr.HTML(label="Highlighted Response")
                with gr.Row():
                    gallery = gr.Gallery(label="Related Images", height=150)
                    image_infos = gr.JSON(label="Image Information", open=False)
        agent = AudioAgent(api_key=os.getenv("OPENAI_API_KEY"))
        def process_with_status(audio_path, last_time, is_proc, video_path, mode_value, session_id_value):
            """Event handler for new audio: debounce, run the pipeline, fan out outputs.

            NOTE(review): the early-return paths return dicts keyed by
            component while the main path returns a 10-tuple; Gradio accepts
            both, but the error tuple and success tuple should be kept in sync
            with the ``outputs`` list below — confirm when editing.
            """
            # Debug logging
            print("=== Input Debug ===")
            print(f"Audio path: {audio_path}, type: {type(audio_path)}")
            print(f"Video path: {video_path}, type: {type(video_path)}")
            print(f"Mode value: {mode_value}")
            print(f"Session ID: {session_id_value}")
            # Video path handling with debug
            video_file = None
            if isinstance(video_path, tuple) and len(video_path) > 0:
                video_file = video_path[0]
                print(f"Extracted video file from tuple: {video_file}")
            elif isinstance(video_path, str):
                video_file = video_path
                print(f"Using video path directly: {video_file}")
            else:
                print(f"Unexpected video_path type: {type(video_path)}, value: {video_path}")
            current_time = time.time()
            # Check if enough time has passed (debounce window) or a request is in flight
            if current_time - last_time < 4 or is_proc:
                return {
                    status: "Processing previous request...",
                    last_processed_time: last_time,
                    is_processing: is_proc
                }
            # Set processing flag
            # NOTE(review): this local assignment is never written back into
            # the is_processing State on the success path — the flag as a
            # concurrency guard likely does nothing; verify intent.
            is_proc = True
            # Process the audio
            if audio_path is None:
                return {
                    status: "No audio detected",
                    last_processed_time: current_time,
                    is_processing: False
                }
            result, audio_path = process_audio_input(
                agent=agent,
                audio_path=audio_path,
                webcam_input=video_path,
                classifier=classifier,
                mode=mode_value,
                session_id=session_id_value
            )
            print("=== Output Debug ===")
            print(f"Result: {result}")
            print(f"Audio response path: {audio_path}")
            # # Handle webcam activation
            # if result.get("action") == "take_video":
            #     return {
            #         video_input: gr.update(interactive=True),
            #         status: "Webcam activated. You can now record a video.",
            #         text_output: "Webcam activated"
            #     }
            if "error" in result:
                return result["error"], None, "", 0, [], "", "", [], {}, f"Error: {result['error']}"
            # status.update(value="Ready")
            # Fan the result dict out to the ten output components, in the
            # same order as the ``outputs`` list of audio_input.change below.
            return (
                result,  # text_output
                audio_path,  # audio_output
                result.get("video_prediction", ""),  # video_prediction
                result.get("video_confidence", 0),  # video_confidence
                result.get("generated_question", []),  # generated_question
                result.get("pure_response", ""),  # pure_response
                result.get("highlighted_response", ""),  # highlighted_response
                result.get("gallery", []),  # gallery
                result.get("image_infos", {}),  # image_infos,
                "Ready"  # status
            )
        audio_input.change(
            fn=process_with_status,
            inputs=[
                audio_input,
                last_processed_time,
                is_processing,
                video_input,
                mode,
                session_id
            ],
            outputs=[
                text_output,
                audio_output,
                video_prediction,
                video_confidence,
                generated_question,
                pure_response,
                highlighted_response,
                gallery,
                image_infos,
                status
            ]
        )
    print("Launching interface...")
    return demo
if __name__ == "__main__":
    # Script entry point: build the Gradio app and serve it with a public share link.
    print("Script started")
    demo = main()
    demo.launch(share=True)