import argparse
import base64
import concurrent.futures
import json
import multiprocessing
import os
import re
import sys
import time
import traceback
import uuid
from datetime import datetime
from functools import partial

import torch
from openai import AzureOpenAI, OpenAI
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from volcenginesdkarkruntime import Ark

# --- Configuration for SigLIP Model ---
# Local path of the SigLIP checkpoint that the embedding server loads to embed text queries.
SIGLIP_MODEL_ID = (
    "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
)

# Extracts the numeric frame ID from file names such as "frame_0550.jpg".
FRAME_ID_PATTERN = re.compile(r"frame_(\d+)\.jpg")

# Upper bound on frames returned by one get_frames_by_id call; matches the
# "up to 10" limit advertised to the model in the tool schema.
MAX_FRAMES_PER_ID_CALL = 10

# Seconds a worker waits for the embedding server to answer one text query.
EMBEDDING_REQUEST_TIMEOUT_S = 300

# Key the embedding server sets in the shared dict once its model is loaded.
# Request IDs are uuid4 strings, so this cannot collide with a real request.
_SERVER_READY_KEY = "__embedding_server_ready__"

# Message that asks the embedding server to exit (request_id=None marks it).
_STOP_REQUEST = (None, "STOP")

# --- System prompt explaining the two tools with examples ---
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering. Your task is to answer a multiple-choice question based on a video by strategically retrieving and analyzing its frames.

You have two tools to retrieve frames. Both return images directly.

1. `get_frames_by_id(frame_ids)`: Retrieves frames using their specific numerical IDs. Use this when the question provides direct temporal clues or when you need to view specific frames identified by another tool.
    * **Example Use Case:** For a question like "What happens at the 1 minute 30 second mark?", you can calculate the approximate frame ID and use this tool to see the visual.
    * **Example Use Case:** For "Describe the action in frame 550.", you would call this tool with `frame_ids=[550]`.

2. `get_frames_by_similarity(query)`: Searches the entire video for frames that visually match a text description and returns the top 5 most relevant frames directly. Use this for content-based questions where the timing is unknown.
    * **Example Use Case:** For a question like "What color is the main character's car?", you would use this tool with a query like "the main character's car".
    * **Example Use Case:** For "Find the scene where a band is playing on stage", you would use the query "a band playing on stage".

Your strategy must be efficient:
1. **Analyze the Query:** First, determine if the question is temporal/logical (better for `get_frames_by_id`) or content-based (requires `get_frames_by_similarity`).
2. **Retrieve & Analyze:** Call the most appropriate tool. Analyze the returned frames to form a hypothesis.
3. **Iterate:** If you need more information, refine your search query for the similarity tool or calculate new frame IDs for the ID tool and call again.
4. **Final Answer:** Once you have gathered enough visual evidence, provide your step-by-step reasoning and then the final answer in the specified JSON format. Do not guess.

Your output should follow this format exactly:
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""

# Tool Schemas
GET_FRAMES_BY_ID_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_id",
        "description": "Retrieves specific video frames by their numerical IDs to get visual information.",
        "parameters": {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "A list of up to 10 frame numbers to retrieve.",
                },
            },
            "required": ["frame_ids"],
        },
    },
}

GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_similarity",
        "description": "Searches for and retrieves the top 5 most visually relevant frames for a given text query. Use this to locate visual content when frame numbers are unknown.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "A concise text description of the visual content to search for (e.g., 'a person playing piano').",
                },
            },
            "required": ["query"],
        },
    },
}

# Argument names each tool accepts from the model, derived from the schemas above.
# Model-supplied arguments are filtered against this so the model cannot override
# internal keyword arguments (e.g. all_frame_paths, k, max_frames).
_TOOL_PARAMETERS = {
    schema["function"]["name"]: set(schema["function"]["parameters"]["properties"])
    for schema in (GET_FRAMES_BY_ID_TOOL_SCHEMA, GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA)
}


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Agentic Video QA with Hybrid Frame Retrieval"
    )
    parser.add_argument(
        "--target-model", "-tm", type=str, required=True, help="Model to evaluate."
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Base directory for video frames.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Path to the evaluation dataset.",
    )
    parser.add_argument(
        "--embeddings-path",
        "-ep",
        type=str,
        required=True,
        help="Directory with pre-computed frame embeddings.",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Max retries for API calls.",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=20,
        help="Number of parallel processes.",
    )
    parser.add_argument("--base_url", type=str, required=True, help="API endpoint URL.")
    parser.add_argument("--api_key", type=str, required=True, help="API key.")
    return parser.parse_args()


def save_json_file(data, output_file):
    """Write ``data`` to ``output_file`` as indented UTF-8 JSON, overwriting it."""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)


def extract_json_from_response(response):
    """Return the JSON object inside the first ```json ... ``` fence of ``response``.

    Returns None when ``response`` is empty, contains no such fence, or the fenced
    text is not valid JSON.
    """
    if not response:
        return None
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except (json.JSONDecodeError, IndexError):
            return None
    return None


def calculate_metrics(results):
    """Compute accuracy metrics over results that do not carry an "error" key.

    ``accuracy`` is correct / answered, i.e. it ignores items where no answer
    could be parsed. ``accuracy_over_total`` is correct / all valid items.
    """
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
            "accuracy_over_total": 0.0,
        }
    answered_samples = sum(
        1 for x in valid_results if x.get("model_answer") is not None
    )
    correct_answers = sum(1 for x in valid_results if x.get("is_correct"))
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0
    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
        "accuracy_over_total": correct_answers / total_samples,
    }


def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Call the chat-completions API, retrying on any exception.

    Returns the first choice's message. The last exception is re-raised once
    ``max_retry_times`` attempts have failed; returns None if
    ``max_retry_times`` <= 0 (no attempt is made).
    """
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    for retry in range(max_retry_times):
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            print(
                f"API Error for item {item_id}: {str(e)}. Retrying ({retry + 1}/{max_retry_times})..."
            )
            if retry == max_retry_times - 1:
                raise
            time.sleep(5)
    return None


def _frame_id_from_name(name):
    """Return the integer ID in a name like ".../frame_0550.jpg", or None if absent."""
    match = FRAME_ID_PATTERN.search(os.path.basename(name))
    return int(match.group(1)) if match else None


def _coerce_frame_ids(frame_ids):
    """Normalize model-supplied frame IDs to a list of ints, skipping unparseable entries."""
    if frame_ids is None:
        return []
    # Models occasionally send a bare number instead of a list.
    if isinstance(frame_ids, (int, float, str)):
        frame_ids = [frame_ids]
    ids = []
    for raw_id in frame_ids:
        try:
            ids.append(int(raw_id))
        except (TypeError, ValueError):
            continue
    return ids


def encode_image(image_path):
    """Return the contents of ``image_path`` as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def get_frames_by_id(
    frame_ids: list,
    all_frame_paths: list,
    max_frames: int = MAX_FRAMES_PER_ID_CALL,
):
    """Tool implementation: return base64 image_url parts for the requested frame IDs.

    Args:
        frame_ids: Frame IDs to fetch (ints, numeric strings, or a single value).
        all_frame_paths: Paths of all frame files for the video.
        max_frames: Maximum number of frames returned; None disables the cap.

    Returns:
        A list of ``{"type": "image_url", ...}`` dicts, in request order, for
        IDs that map to an existing file. Unknown IDs are silently skipped.
    """
    frame_map = {}
    for path in all_frame_paths:
        fid = _frame_id_from_name(path)
        if fid is not None:
            frame_map[fid] = path

    retrieved_frames = []
    for fid in _coerce_frame_ids(frame_ids):
        if max_frames is not None and len(retrieved_frames) >= max_frames:
            break
        path = frame_map.get(fid)
        if path and os.path.exists(path):
            b64_image = encode_image(path)
            retrieved_frames.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                }
            )
    return retrieved_frames


def get_frames_by_similarity(
    query: str,
    all_frame_paths: list,
    precomputed_data: dict,
    request_queue: multiprocessing.Queue,
    results_dict: dict,
    k: int = 5,
    timeout: float = EMBEDDING_REQUEST_TIMEOUT_S,
):
    """Tool implementation: return the ``k`` frames most similar to a text query.

    The text embedding is requested from the embedding server process via
    ``request_queue``; the server writes the result (or an error string) into
    ``results_dict`` under a fresh request ID.

    Args:
        query: Text description to search for.
        all_frame_paths: Paths of all frame files for the video.
        precomputed_data: Dict with "filenames" (frame file names) and
            "embeddings" (a tensor with one row per filename, presumably in the
            same order - confirm against the embedding pre-computation script).
        request_queue: Queue shared with the embedding server.
        results_dict: Shared dict the embedding server writes results into.
        k: Number of frames to return.
        timeout: Seconds to wait for the embedding server before giving up.

    Raises:
        TimeoutError: the embedding server did not answer within ``timeout``.
        RuntimeError: the embedding server reported a failure for this query.
    """
    frame_filenames = precomputed_data["filenames"]
    # A single video's frame embeddings are small; computing the similarity on
    # CPU avoids creating a CUDA context in every worker process. float32 also
    # avoids dtype mismatches between stored and freshly computed embeddings.
    frame_embeddings = precomputed_data["embeddings"].float().cpu()
    if k <= 0 or len(frame_filenames) == 0:
        return []

    # 1. Send the request to the embedding server process.
    request_id = str(uuid.uuid4())
    request_queue.put((request_id, str(query)))

    # 2. Wait for the result, but never forever (the server may have died).
    deadline = time.monotonic() + timeout
    while request_id not in results_dict:
        if time.monotonic() > deadline:
            raise TimeoutError(
                f"Embedding server did not answer query {query!r} within {timeout}s"
            )
        time.sleep(0.05)
    payload = results_dict.pop(request_id)
    # The server stores a string instead of a tensor when embedding fails.
    if isinstance(payload, str):
        raise RuntimeError(f"Embedding server failed for query {query!r}: {payload}")
    query_embedding = payload.float().cpu()

    # 3. Rank frames by cosine similarity to the query embedding.
    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)
    num_frames_to_select = min(k, len(frame_filenames), similarities.numel())
    top_k_indices = (
        torch.topk(similarities, k=num_frames_to_select, dim=-1).indices.flatten().tolist()
    )

    top_k_frame_ids = [_frame_id_from_name(frame_filenames[i]) for i in top_k_indices]
    top_k_frame_ids = [fid for fid in top_k_frame_ids if fid is not None]
    # k already bounds the count, so the ID tool's cap must not truncate it.
    return get_frames_by_id(top_k_frame_ids, all_frame_paths, max_frames=None)


def _make_client(api_key, base_url):
    """Create the API client that matches ``base_url`` (Ark, OpenAI-compatible, or Azure)."""
    if "ark" in base_url:
        return Ark(base_url=base_url, api_key=api_key)
    if "aliyun" in base_url or "127.0.0.1" in base_url:
        return OpenAI(api_key=api_key, base_url=base_url)
    return AzureOpenAI(
        api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
    )


def _build_initial_prompt(data_item, all_frame_paths):
    """Build the first user message: frame count, real frame-ID range, duration, question."""
    total_frames = len(all_frame_paths)
    frame_ids = [
        fid for fid in map(_frame_id_from_name, all_frame_paths) if fid is not None
    ]
    # Report the actual ID range so computed IDs line up with the files on disk.
    first_id, last_id = (
        (min(frame_ids), max(frame_ids)) if frame_ids else (1, total_frames)
    )
    duration_minutes = (data_item.get("video_info") or {}).get("duration_minutes") or 0
    duration = float(duration_minutes) * 60
    return (
        f"The video has {total_frames} frames (ID {first_id} to {last_id}) and is {duration:.0f} seconds long. "
        f"Please answer this question:\n{data_item['question']}"
    )


def _run_tool_calls(tool_calls, available_functions):
    """Execute the model's tool calls.

    Returns:
        (tool_messages, frame_messages). ``tool_messages`` has exactly one
        "tool" message per call, so every tool_call_id gets a response, even for
        unknown tools or bad arguments. ``frame_messages`` holds the user
        messages carrying the retrieved images. They are kept separate so that
        the caller can place all tool responses directly after the assistant
        message, as chat-completions APIs require.
    """
    tool_messages = []
    frame_messages = []
    for tool_call in tool_calls:
        function_name = tool_call.function.name
        function_to_call = available_functions.get(function_name)
        frames = None
        if function_to_call is None:
            status = {"status": "error", "message": f"Unknown tool: {function_name}"}
        else:
            try:
                function_args = json.loads(tool_call.function.arguments or "{}")
                if not isinstance(function_args, dict):
                    raise ValueError("tool arguments must be a JSON object")
                allowed = _TOOL_PARAMETERS.get(function_name, set())
                frames = function_to_call(
                    **{k: v for k, v in function_args.items() if k in allowed}
                )
            except (TypeError, ValueError) as e:
                # Bad model-supplied arguments: report back instead of failing the item.
                status = {
                    "status": "error",
                    "message": f"Invalid arguments for {function_name}: {e}",
                }
            else:
                status = {"status": "success", "retrieved_frame_count": len(frames)}

        tool_messages.append(
            {
                "tool_call_id": tool_call.id,
                "role": "tool",
                "name": function_name,
                "content": json.dumps(status),
            }
        )
        if frames is not None:
            frame_messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"Here are the {len(frames)} frames from your call to `{function_name}`.",
                        },
                        *frames,
                    ],
                }
            )
    return tool_messages, frame_messages


def evaluate_single_item_agentic(
    data_item,
    all_frame_paths,
    embeddings_data,
    target_model,
    api_key,
    base_url,
    max_retry_times,
    request_queue,
    results_dict,
):
    """Answer one QA item with a tool-calling loop and score the answer.

    The model may issue tool calls for up to 10 rounds; if it has not answered
    by then it is asked for a final answer without tools.

    Returns:
        ``data_item`` extended with the conversation, the raw final response,
        the parsed answer, and ``is_correct``; or None if the API returned no
        message.
    """
    client = _make_client(api_key, base_url)
    tools = [GET_FRAMES_BY_ID_TOOL_SCHEMA, GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA]
    available_functions = {
        "get_frames_by_id": partial(get_frames_by_id, all_frame_paths=all_frame_paths),
        "get_frames_by_similarity": partial(
            get_frames_by_similarity,
            all_frame_paths=all_frame_paths,
            precomputed_data=embeddings_data,
            request_queue=request_queue,
            results_dict=results_dict,
        ),
    }

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": _build_initial_prompt(data_item, all_frame_paths)},
    ]

    response_content = None
    max_tool_calls = 10
    for _ in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None
        messages.append(response_message)

        if not response_message.tool_calls:
            response_content = response_message.content
            break

        tool_messages, frame_messages = _run_tool_calls(
            response_message.tool_calls, available_functions
        )
        # All tool responses first, then the image-bearing user messages.
        messages.extend(tool_messages)
        messages.extend(frame_messages)

    if response_content is None:
        final_prompt = "You have reached the maximum number of tool calls. Provide a final answer based on the information gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        response_content = (
            final_response.content
            if final_response
            else "Could not determine an answer after max tool calls."
        )

    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
    is_correct = (
        model_answer_cleaned is not None
        and model_answer_cleaned == str(data_item["answer"]).strip().upper()
    )

    return {
        **data_item,
        "agent_conversation": [
            msg if isinstance(msg, dict) else msg.model_dump() for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }


def process_single_data(data_item, args, request_queue, results_dict):
    """Worker entry point: load one video's frames/embeddings and evaluate the item.

    Any exception is caught and turned into an error record with "key", "uid",
    "error" and "traceback" so that the pool keeps running.
    """
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")

        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(
                f"Frame directory not found: {specific_frames_path}"
            )
        if not os.path.exists(embedding_file):
            raise FileNotFoundError(f"Embedding file not found: {embedding_file}")

        # Keep only "frame_<N>.jpg" files (others used to crash the sort key),
        # ordered by numeric frame ID.
        frame_entries = []
        for name in os.listdir(specific_frames_path):
            if not name.endswith(".jpg"):
                continue
            fid = _frame_id_from_name(name)
            if fid is not None:
                frame_entries.append((fid, os.path.join(specific_frames_path, name)))
        frame_entries.sort()
        all_frame_paths = [path for _, path in frame_entries]

        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        embeddings_data = torch.load(embedding_file, map_location="cpu")

        return evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            embeddings_data,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
            request_queue,
            results_dict,
        )
    except Exception as e:
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }


def load_test_data(json_file):
    """Load the evaluation dataset; exit with status 1 if missing or malformed."""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
        sys.exit(1)


def embedding_server_process(model_id, device, request_queue, results_dict):
    """Serve text-embedding requests for the worker processes.

    Loads the SigLIP model, marks itself ready in ``results_dict``, then reads
    ``(request_id, text)`` tuples from ``request_queue``. For each one it stores
    the CPU embedding tensor, or an error string on failure, in
    ``results_dict[request_id]``. It stops on a request whose ID is None, or
    when the shared queue becomes unreachable.
    """
    print(f"Embedding server started on PID {os.getpid()}...")
    print("Loading SigLIP model in the embedding server process...")
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print("SigLIP model loaded in server.")
    model.to(device)
    model.eval()
    results_dict[_SERVER_READY_KEY] = True

    while True:
        try:
            request_id, text_query = request_queue.get()
        except (EOFError, OSError) as e:
            # The Manager owning the queue is gone; no further requests can arrive.
            print(f"Embedding server lost its request queue ({e}). Shutting down.")
            break

        # Sentinel is keyed on the ID, so a model query that happens to be
        # the text "STOP" is embedded normally instead of killing the server.
        if request_id is None:
            print("Embedding server received stop signal. Shutting down.")
            break

        try:
            with torch.no_grad():
                # SigLIP was trained with max_length padding; other padding
                # modes yield text embeddings that match images poorly.
                text_inputs = processor(
                    text=[text_query],
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                ).to(device)
                query_embedding = model.get_text_features(**text_inputs)
            # Move embedding to CPU before sharing across processes.
            results_dict[request_id] = query_embedding.cpu()
        except Exception as e:
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
            # Answer the request anyway so the waiting worker fails fast.
            try:
                results_dict[request_id] = f"{type(e).__name__}: {e}"
            except (EOFError, OSError):
                break


def _load_existing_results(results_output_file):
    """Load saved results for resuming, keeping only successful entries.

    Entries with an "error" key are dropped so their items are re-queued
    instead of being counted as completed.
    """
    if not os.path.exists(results_output_file):
        return []
    try:
        with open(results_output_file, "r", encoding="utf-8") as f:
            loaded = json.load(f)
    except (json.JSONDecodeError, IOError):
        return []
    if not isinstance(loaded, list):
        return []
    successful = [
        item for item in loaded if isinstance(item, dict) and "error" not in item
    ]
    dropped = len(loaded) - len(successful)
    if dropped:
        print(f"Discarding {dropped} errored result(s); they will be retried.")
    return successful


def _wait_for_embedding_server(server, results_dict, poll_interval=0.5):
    """Block until the embedding server is ready; raise if it exits first."""
    while _SERVER_READY_KEY not in results_dict:
        if not server.is_alive():
            raise RuntimeError(
                f"Embedding server exited (code {server.exitcode}) before becoming ready."
            )
        time.sleep(poll_interval)


def _stop_embedding_server(server, request_queue, timeout=60):
    """Ask the embedding server to exit, terminating it if it does not stop in time."""
    if server.is_alive():
        try:
            request_queue.put(_STOP_REQUEST)
        except (EOFError, OSError):
            # Queue already unreachable; fall through to join/terminate.
            pass
        server.join(timeout)
    if server.is_alive():
        print("Embedding server did not stop in time; terminating it.")
        server.terminate()
    server.join()


def main():
    """Run the evaluation: resume state, dispatch workers, save results and metrics."""
    args = parse_arguments()
    print("--- Agentic Video QA with Hybrid Retrieval ---")
    print(
        f"Model: {args.target_model}, Data: {args.data_file}, Embeddings: {args.embeddings_path}"
    )

    # 'spawn' is required for safe CUDA use in child processes (and on macOS/Windows).
    try:
        multiprocessing.set_start_method("spawn", force=True)
        print("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        print("Start method already set.")

    # The SigLIP model is loaded only in the dedicated embedding server process.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_hybrid"
    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    existing_results = _load_existing_results(results_output_file)
    completed_ids = {item["uid"] for item in existing_results if "uid" in item}
    if existing_results:
        print(f"Found {len(completed_ids)} completed tasks. Resuming...")

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]

    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
    else:
        print(
            f"Total: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
        )

    all_results = list(existing_results)

    if tasks_to_process:
        with multiprocessing.Manager() as manager:
            request_queue = manager.Queue()
            results_dict = manager.dict()

            embedding_server = multiprocessing.Process(
                target=embedding_server_process,
                args=(SIGLIP_MODEL_ID, device, request_queue, results_dict),
            )
            embedding_server.start()
            try:
                # Fail fast if the model could not be loaded, rather than letting
                # every worker time out on its similarity queries.
                _wait_for_embedding_server(embedding_server, results_dict)

                with concurrent.futures.ProcessPoolExecutor(
                    max_workers=args.pool_processes
                ) as executor:
                    func = partial(
                        process_single_data,
                        args=args,
                        request_queue=request_queue,
                        results_dict=results_dict,
                    )
                    results_iterator = executor.map(func, tasks_to_process)
                    for result in tqdm(
                        results_iterator,
                        total=len(tasks_to_process),
                        desc="Processing Videos",
                    ):
                        if not result:
                            continue
                        if "error" in result:
                            with open(error_log_file, "a", encoding="utf-8") as f:
                                f.write(
                                    f"Error on key {result.get('key', 'N/A')}:\n  Error: {result['error']}\n  Traceback: {result['traceback']}\n---\n"
                                )
                        all_results.append(result)
                        if len(all_results) % 10 == 0:
                            save_json_file(all_results, results_output_file)

                print("All tasks processed. Sending stop signal to embedding server.")
            finally:
                # Stop the server even if the pool failed, so it does not outlive
                # (and spin on) the Manager that owns its queue.
                _stop_embedding_server(embedding_server, request_queue)

    print("\n\nProcessing complete.")
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    main()