|
|
import argparse
import base64
import concurrent.futures
import json
import multiprocessing
import os
import re
import sys
import time
import traceback
import uuid
from datetime import datetime
from functools import partial

import torch
from openai import AzureOpenAI, OpenAI
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
from volcenginesdkarkruntime import Ark
|
|
|
|
|
|
|
|
|
|
|
# Local filesystem path to the SigLIP checkpoint loaded by the embedding
# server process (see embedding_server_process) to encode text queries.
SIGLIP_MODEL_ID = (
    "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
)
|
|
|
|
|
|
|
|
# System prompt for the tool-calling agent. It describes the two retrieval
# tools (ID-based and similarity-based) and pins the required output format:
# free-form reasoning followed by a fenced ```json {"answer": "X"} ``` block,
# which extract_json_from_response parses for grading.
AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specialized in video question answering.
Your task is to answer a multiple-choice question based on a video by strategically retrieving and analyzing its frames.

You have two tools to retrieve frames. Both return images directly.

1. `get_frames_by_id(frame_ids)`: Retrieves frames using their specific numerical IDs. Use this when the question provides direct temporal clues or when you need to view specific frames identified by another tool.
    * **Example Use Case:** For a question like "What happens at the 1 minute 30 second mark?", you can calculate the approximate frame ID and use this tool to see the visual.
    * **Example Use Case:** For "Describe the action in frame 550.", you would call this tool with `frame_ids=[550]`.

2. `get_frames_by_similarity(query)`: Searches the entire video for frames that visually match a text description and returns the top 5 most relevant frames directly. Use this for content-based questions where the timing is unknown.
    * **Example Use Case:** For a question like "What color is the main character's car?", you would use this tool with a query like "the main character's car".
    * **Example Use Case:** For "Find the scene where a band is playing on stage", you would use the query "a band playing on stage".

Your strategy must be efficient:
1. **Analyze the Query:** First, determine if the question is temporal/logical (better for `get_frames_by_id`) or content-based (requires `get_frames_by_similarity`).
2. **Retrieve & Analyze:** Call the most appropriate tool. Analyze the returned frames to form a hypothesis.
3. **Iterate:** If you need more information, refine your search query for the similarity tool or calculate new frame IDs for the ID tool and call again.
4. **Final Answer:** Once you have gathered enough visual evidence, provide your step-by-step reasoning and then the final answer in the specified JSON format. Do not guess.

Your output should follow this format exactly:
<Your step-by-step reasoning here>
```json
{"answer": "X"}
```
Do not include any other text after the JSON block.
"""
|
|
|
|
|
|
|
|
# OpenAI-style function-calling schema advertised to the model for fetching
# frames by explicit numeric ID; implemented by get_frames_by_id below.
GET_FRAMES_BY_ID_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_id",
        "description": "Retrieves specific video frames by their numerical IDs to get visual information.",
        "parameters": {
            "type": "object",
            "properties": {
                "frame_ids": {
                    "type": "array",
                    "items": {"type": "integer"},
                    "description": "A list of up to 10 frame numbers to retrieve.",
                },
            },
            "required": ["frame_ids"],
        },
    },
}
|
|
|
|
|
# Function-calling schema for text-to-frame similarity search; implemented by
# get_frames_by_similarity below (top-5 frames via SigLIP embeddings).
GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "get_frames_by_similarity",
        "description": "Searches for and retrieves the top 5 most visually relevant frames for a given text query. Use this to locate visual content when frame numbers are unknown.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "A concise text description of the visual content to search for (e.g., 'a person playing piano').",
                },
            },
            "required": ["query"],
        },
    },
}
|
|
|
|
|
|
|
|
def parse_arguments():
    """Define and parse the command-line interface for this evaluation script.

    Returns an argparse.Namespace with target_model, frames_path, data_file,
    embeddings_path, max_retry_times, pool_processes, base_url and api_key.
    """
    p = argparse.ArgumentParser(
        description="Agentic Video QA with Hybrid Frame Retrieval"
    )
    p.add_argument(
        "--target-model", "-tm", type=str, required=True, help="Model to evaluate."
    )
    p.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Base directory for video frames.",
    )
    p.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="Path to the evaluation dataset.",
    )
    p.add_argument(
        "--embeddings-path", "-ep", type=str, required=True,
        help="Directory with pre-computed frame embeddings.",
    )
    p.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Max retries for API calls.",
    )
    p.add_argument(
        "--pool-processes", "-pp", type=int, default=20,
        help="Number of parallel processes.",
    )
    p.add_argument("--base_url", type=str, required=True, help="API endpoint URL.")
    p.add_argument("--api_key", type=str, required=True, help="API key.")
    return p.parse_args()
|
|
|
|
|
|
|
|
def save_json_file(data, output_file):
    """Serialize *data* to *output_file* as pretty-printed JSON.

    The file is written as UTF-8 with ``ensure_ascii=False`` so that any
    non-ASCII text in questions/answers stays human-readable instead of
    being escaped to \\uXXXX sequences.
    """
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
|
|
|
|
|
|
|
|
def extract_json_from_response(response):
    """Extract the final fenced JSON object from a model response string.

    The agent prompt instructs the model to end its reply with a
    ```json {...} ``` block. All fenced blocks are scanned and the *last*
    one that parses is returned, so a malformed or intermediate block
    earlier in the reasoning cannot hide the final answer (the original
    implementation only tried the first match and gave up on a parse error).

    Returns the parsed dict, or None when *response* is falsy or no fenced
    block parses as JSON.
    """
    if not response:
        return None
    candidates = re.findall(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    for candidate in reversed(candidates):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue  # try the previous fenced block instead
    return None
|
|
|
|
|
|
|
|
def calculate_metrics(results):
    """Summarize accuracy statistics over a list of per-item results.

    Items containing an "error" key are excluded. Accuracy is computed over
    *answered* samples only (items whose "model_answer" is not None), not
    over all valid samples.
    """
    usable = [item for item in results if "error" not in item]
    if not usable:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }

    answered = len([item for item in usable if item.get("model_answer") is not None])
    correct = len([item for item in usable if item.get("is_correct")])
    return {
        "total_samples": len(usable),
        "answered_samples": answered,
        "correct_answers": correct,
        "accuracy": correct / answered if answered else 0.0,
    }
|
|
|
|
|
|
|
|
def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """Issue one chat-completions request, retrying on any exception.

    Attaches the tool schemas (with tool_choice="auto") when *tools* is
    given. On failure the error is printed and the call retried after a 5s
    pause; the last error is re-raised once *max_retry_times* attempts are
    exhausted. Returns the first choice's message object.
    """
    request = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        request["tools"] = tools
        request["tool_choice"] = "auto"

    for attempt in range(1, max_retry_times + 1):
        try:
            completion = client.chat.completions.create(**request)
        except Exception as err:
            print(
                f"API Error for item {item_id}: {str(err)}. Retrying ({attempt}/{max_retry_times})..."
            )
            if attempt == max_retry_times:
                raise err
            time.sleep(5)
        else:
            return completion.choices[0].message
|
|
|
|
|
|
|
|
def get_frames_by_id(frame_ids: list, all_frame_paths: list):
    """Tool implementation: retrieve frames by numeric ID as image content parts.

    Builds a map from frame number (parsed from the "frame_<N>.jpg" filename
    pattern) to path, then returns a base64 data-URL "image_url" content part
    for each requested ID that exists on disk. IDs that do not resolve to an
    existing file are silently skipped.

    The original version ran the same regex twice per path (once in the dict
    comprehension's key, once in its filter); this version compiles the
    pattern once and matches each path a single time.
    """
    pattern = re.compile(r"frame_(\d+)\.jpg")
    frame_map = {}
    for path in all_frame_paths:
        match = pattern.search(os.path.basename(path))
        if match:
            frame_map[int(match.group(1))] = path

    retrieved_frames = []
    for fid in frame_ids:
        path = frame_map.get(fid)
        if path and os.path.exists(path):
            b64_image = encode_image(path)
            retrieved_frames.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                }
            )
    return retrieved_frames
|
|
|
|
|
|
|
|
|
|
|
def get_frames_by_similarity(
    query: str,
    all_frame_paths: list,
    precomputed_data: dict,
    request_queue: multiprocessing.Queue,
    results_dict: dict,
    k: int = 5,
):
    """
    Requests a text embedding from the server process, calculates similarity,
    finds top-k frames, and returns them encoded.

    Args:
        query: Natural-language description of the visual content to find.
        all_frame_paths: Paths to every extracted frame of this video.
        precomputed_data: Dict with "filenames" (frame file names) and
            "embeddings" (tensor of per-frame embeddings, one row per file).
        request_queue: Shared queue for sending (request_id, query) pairs to
            the embedding server process.
        results_dict: Shared manager dict the server writes CPU embeddings
            into, keyed by request_id.
        k: Number of top-matching frames to return (default 5).

    Returns:
        A list of image_url content parts (as built by get_frames_by_id) for
        the k frames most similar to the query.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_filenames = precomputed_data["filenames"]
    frame_embeddings = precomputed_data["embeddings"].to(device)

    # A UUID ties the server's reply in results_dict back to this request.
    request_id = str(uuid.uuid4())
    request_queue.put((request_id, query))

    # Poll until the embedding server has answered.
    # NOTE(review): unbounded busy-wait — if the server process dies this
    # loop spins forever; consider adding a timeout.
    while request_id not in results_dict:
        time.sleep(0.05)
    query_embedding = results_dict.pop(request_id).to(device)

    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)

    # Clamp k so topk never asks for more frames than exist.
    num_frames_to_select = min(k, len(frame_filenames))
    top_k_indices = (
        torch.topk(similarities, k=num_frames_to_select, dim=-1)
        .indices.cpu()
        .flatten()
        .numpy()
    )

    # Map the winning rows back to frame IDs; filenames are assumed to follow
    # the "frame_<N>.jpg" pattern (regex would raise AttributeError otherwise).
    top_k_filenames = [frame_filenames[i] for i in top_k_indices]
    top_k_frame_ids = [
        int(re.search(r"frame_(\d+)\.jpg", f).group(1)) for f in top_k_filenames
    ]

    # Reuse the ID-based tool to load and base64-encode the selected frames.
    retrieved_frames = get_frames_by_id(top_k_frame_ids, all_frame_paths)
    return retrieved_frames
|
|
|
|
|
|
|
|
def evaluate_single_item_agentic(
    data_item,
    all_frame_paths,
    embeddings_data,
    target_model,
    api_key,
    base_url,
    max_retry_times,
    request_queue,
    results_dict,
):
    """Evaluates a single item using an agentic loop.

    Runs a tool-calling conversation with the target model: the model may
    request frames by ID or by similarity search for up to 10 rounds before
    being forced to answer. The final JSON answer is parsed and graded
    against data_item["answer"].

    Returns:
        The data_item dict enriched with the full conversation, the raw
        final response, the cleaned answer, and an is_correct flag — or
        None when the API returned no message.
    """
    # Select the SDK client based on the endpoint URL.
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
        client = AzureOpenAI(
            api_version="2023-05-15", api_key=api_key, azure_endpoint=base_url
        )

    tools = [GET_FRAMES_BY_ID_TOOL_SCHEMA, GET_FRAMES_BY_SIMILARITY_TOOL_SCHEMA]

    # Bind the per-video context so tools can be invoked with only the
    # model-supplied arguments.
    get_frames_by_id_with_context = partial(
        get_frames_by_id, all_frame_paths=all_frame_paths
    )

    get_frames_by_similarity_with_context = partial(
        get_frames_by_similarity,
        all_frame_paths=all_frame_paths,
        precomputed_data=embeddings_data,
        request_queue=request_queue,
        results_dict=results_dict,
    )

    # Dispatch table from tool name (as returned by the model) to callable.
    available_functions = {
        "get_frames_by_id": get_frames_by_id_with_context,
        "get_frames_by_similarity": get_frames_by_similarity_with_context,
    }

    total_frames = len(all_frame_paths)
    # duration_minutes -> seconds; defaults to 0 when video_info is absent.
    duration = data_item.get("video_info", {}).get("duration_minutes", 0) * 60
    initial_prompt = (
        f"The video has {total_frames} frames (ID 1 to {total_frames}) and is {duration:.0f} seconds long. "
        f"Please answer this question:\n{data_item['question']}"
    )

    messages = [
        {"role": "system", "content": AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt},
    ]
    response_content = None
    max_tool_calls = 10  # hard cap on agent iterations

    for _ in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        # The SDK message object is appended as-is; it is converted to a
        # plain dict only when the result is assembled below.
        messages.append(response_message)

        if response_message.tool_calls:
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                function_to_call = available_functions.get(function_name)
                # NOTE(review): an unknown tool name is silently skipped,
                # leaving its tool_call without a "tool" reply message.
                if function_to_call:
                    function_args = json.loads(tool_call.function.arguments)
                    function_response = function_to_call(**function_args)

                    # Acknowledge the tool call with a small JSON status; the
                    # images themselves go in a follow-up user message because
                    # tool messages cannot carry image content parts.
                    messages.append(
                        {
                            "tool_call_id": tool_call.id,
                            "role": "tool",
                            "name": function_name,
                            "content": json.dumps(
                                {
                                    "status": "success",
                                    "retrieved_frame_count": len(function_response),
                                }
                            ),
                        }
                    )

                    user_message_with_frames = [
                        {
                            "type": "text",
                            "text": f"Here are the {len(function_response)} frames from your call to `{function_name}`.",
                        }
                    ]
                    user_message_with_frames.extend(function_response)
                    messages.append(
                        {"role": "user", "content": user_message_with_frames}
                    )
        else:
            # No tool call: the model produced its final answer.
            response_content = response_message.content
            break

    if response_content is None:
        # Tool budget exhausted: force an answer from the evidence gathered.
        final_prompt = "You have reached the maximum number of tool calls. Provide a final answer based on the information gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        response_content = (
            final_response.content
            if final_response
            else "Could not determine an answer after max tool calls."
        )

    # Grade: compare the JSON "answer" field case-insensitively against the
    # ground-truth answer.
    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        if model_answer_cleaned == data_item["answer"].strip().upper():
            is_correct = True

    return {
        **data_item,
        # Convert SDK message objects to plain dicts so the record is
        # JSON-serializable.
        "agent_conversation": [
            msg if isinstance(msg, dict) else msg.model_dump() for msg in messages
        ],
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
    }
|
|
|
|
|
|
|
|
def encode_image(image_path):
    """Read the file at *image_path* and return its contents base64-encoded
    as an ASCII string (suitable for a data: URL)."""
    with open(image_path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
|
|
|
|
|
|
|
|
|
|
|
def process_single_data(data_item, args, request_queue, results_dict):
    """Main processing function for a single video, executed by a worker.

    Resolves the item's frame directory and precomputed-embedding file, loads
    the embeddings, and runs the agentic evaluation. Any exception is caught
    and converted into an error record so one bad item cannot take down the
    worker pool; main() logs these records to the error log.
    """
    item_key = data_item["key"]
    try:
        specific_frames_path = os.path.join(args.frames_path, item_key)
        embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")

        if not os.path.isdir(specific_frames_path):
            raise FileNotFoundError(
                f"Frame directory not found: {specific_frames_path}"
            )
        if not os.path.exists(embedding_file):
            raise FileNotFoundError(f"Embedding file not found: {embedding_file}")

        # Sort frames numerically by the <N> in "frame_<N>.jpg"; plain
        # lexicographic order would put frame_10 before frame_2.
        all_frame_paths = sorted(
            [
                os.path.join(specific_frames_path, f)
                for f in os.listdir(specific_frames_path)
                if f.endswith(".jpg")
            ],
            key=lambda x: int(re.search(r"frame_(\d+)\.jpg", x).group(1)),
        )
        if not all_frame_paths:
            raise FileNotFoundError(f"No frames found for key '{item_key}'")

        # Keep embeddings on CPU here; get_frames_by_similarity moves them
        # to the GPU on demand.
        embeddings_data = torch.load(embedding_file, map_location="cpu")

        result = evaluate_single_item_agentic(
            data_item,
            all_frame_paths,
            embeddings_data,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
            request_queue,
            results_dict,
        )
        return result

    except Exception as e:
        print(f"\nCRITICAL ERROR on key {item_key}: {str(e)}")
        traceback.print_exc()
        # Return a structured error record instead of raising.
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }
|
|
|
|
|
|
|
|
def load_test_data(json_file):
    """Load the evaluation dataset from *json_file*.

    Prints a diagnostic and exits the process with status 1 when the file is
    missing or contains malformed JSON (CLI-tool behavior, unchanged).

    Uses ``sys.exit`` instead of the ``exit()`` builtin: ``exit`` is injected
    by the ``site`` module and is not guaranteed to exist (e.g. under
    ``python -S`` or in frozen builds).
    """
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Malformed JSON in {json_file}")
        sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def embedding_server_process(model_id, device, request_queue, results_dict):
    """
    A server process that loads the SigLIP model and continuously fetches
    text queries from a queue, computes their embeddings, and places the
    results in a shared dictionary.

    Args:
        model_id: Path or HF identifier of the SigLIP checkpoint.
        device: Torch device string ("cuda" or "cpu") the model runs on.
        request_queue: Queue of (request_id, query_text) pairs; a query of
            "STOP" shuts the server down.
        results_dict: Shared manager dict; embeddings are stored under their
            request_id, moved to CPU so they can cross the process boundary.
    """
    print(f"Embedding server started on PID {os.getpid()}...")
    print("Loading SigLIP model in the embedding server process...")
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    print("SigLIP model loaded in server.")

    model.to(device)
    model.eval()

    while True:
        try:
            # Blocks until a request (or the STOP sentinel) arrives.
            request_id, text_query = request_queue.get()
            if text_query == "STOP":
                print("Embedding server received stop signal. Shutting down.")
                break

            with torch.no_grad():
                text_inputs = processor(
                    text=[text_query],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                ).to(device)
                query_embedding = model.get_text_features(**text_inputs)

            # Move to CPU before publishing: workers poll results_dict and
            # pick the tensor up in get_frames_by_similarity.
            results_dict[request_id] = query_embedding.cpu()
        except Exception as e:
            # Log and keep serving; a failed request leaves its caller
            # waiting (see the busy-wait note in get_frames_by_similarity).
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Main function to orchestrate the evaluation framework.

    Pipeline: parse CLI args -> prepare output/log files -> load dataset and
    resume from any existing results -> start the single embedding-server
    process -> fan work out over a process pool -> save results/metrics.
    """
    args = parse_arguments()
    print("--- Agentic Video QA with Hybrid Retrieval ---")
    print(
        f"Model: {args.target_model}, Data: {args.data_file}, Embeddings: {args.embeddings_path}"
    )

    # "spawn" is required so CUDA state is not inherited by forked children.
    try:
        multiprocessing.set_start_method("spawn", force=True)
        print("Multiprocessing start method set to 'spawn'.")
    except RuntimeError:
        print("Start method already set.")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Derive output file names from model + dataset so runs don't collide.
    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_name_safe}_{data_filename_base}_agent_hybrid"
    results_output_file = f"{output_prefix}_results.json"
    metrics_output_file = f"{output_prefix}_metrics.json"
    error_log_file = f"{output_prefix}_errors.log"

    with open(error_log_file, "a", encoding="utf-8") as f:
        f.write(
            f"\n=== Log Session Started at {datetime.now()} for {args.target_model} ===\n"
        )

    all_test_data = load_test_data(args.data_file)
    # Resume support: items whose "uid" already appears in the results file
    # are skipped.
    existing_results = []
    completed_ids = set()
    if os.path.exists(results_output_file):
        try:
            with open(results_output_file, "r", encoding="utf-8") as f:
                existing_results = json.load(f)
            if isinstance(existing_results, list):
                completed_ids = {
                    item["uid"] for item in existing_results if "uid" in item
                }
                print(f"Found {len(completed_ids)} completed tasks. Resuming...")
            else:
                existing_results = []
        except (json.JSONDecodeError, IOError):
            # Corrupt/unreadable results file: start fresh.
            existing_results = []

    tasks_to_process = [
        item for item in all_test_data if item.get("uid") not in completed_ids
    ]
    if not tasks_to_process:
        print("All tasks are already completed. Calculating final metrics.")
    else:
        print(
            f"Total: {len(all_test_data)}. Completed: {len(completed_ids)}. To process: {len(tasks_to_process)}."
        )

    all_results = list(existing_results)

    if tasks_to_process:
        # Manager-backed queue/dict are picklable and shared across the
        # embedding server and all pool workers.
        with multiprocessing.Manager() as manager:
            request_queue = manager.Queue()
            results_dict = manager.dict()

            # One dedicated process owns the SigLIP model and serves text
            # embeddings to every worker.
            embedding_server = multiprocessing.Process(
                target=embedding_server_process,
                args=(
                    SIGLIP_MODEL_ID,
                    device,
                    request_queue,
                    results_dict,
                ),
            )
            embedding_server.start()

            with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.pool_processes
            ) as executor:

                func = partial(
                    process_single_data,
                    args=args,
                    request_queue=request_queue,
                    results_dict=results_dict,
                )
                results_iterator = executor.map(func, tasks_to_process)
                for result in tqdm(
                    results_iterator,
                    total=len(tasks_to_process),
                    desc="Processing Videos",
                ):
                    if result:
                        if "error" in result:
                            with open(error_log_file, "a", encoding="utf-8") as f:
                                f.write(
                                    f"Error on key {result.get('key', 'N/A')}:\n  Error: {result['error']}\n  Traceback: {result['traceback']}\n---\n"
                                )
                        all_results.append(result)
                        # Checkpoint every 10 results so a crash loses little.
                        if len(all_results) % 10 == 0:
                            save_json_file(all_results, results_output_file)

            # Sentinel matches the "STOP" check in embedding_server_process.
            print("All tasks processed. Sending stop signal to embedding server.")
            request_queue.put((None, "STOP"))
            embedding_server.join()

    print("\n\nProcessing complete.")
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nMetrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|
|
|
|
|