|
|
import os |
|
|
import json |
|
|
import base64 |
|
|
import argparse |
|
|
import time |
|
|
import re |
|
|
import traceback |
|
|
import uuid |
|
|
import multiprocessing |
|
|
import concurrent.futures |
|
|
from datetime import datetime |
|
|
from functools import partial |
|
|
|
|
|
import requests |
|
|
import torch |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from openai import AzureOpenAI, OpenAI |
|
|
from volcenginesdkarkruntime import Ark |
|
|
from transformers import AutoModel, AutoProcessor |
|
|
from torch.nn.functional import cosine_similarity |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model identifier handed to the embedding server, which passes it to
# AutoModel.from_pretrained / AutoProcessor.from_pretrained (a local path here).
SIGLIP_MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"

# Number of most-similar video frames retrieved for each query image in Step 2.
TOP_K_FRAMES = 8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# System prompt for Step 1: the VLM sees the sampled frames plus the question and
# must reply with a JSON array of {"reference_image_id", "prompt"} objects that
# describe the keyframes to generate. The text is sent verbatim to the model.
STEP_1_PLANNING_PROMPT = """
You are a professional video analyst. Your task is to analyze a question and a few initial video sample frames, then plan what keyframes you need to see to answer the question.

Do not answer the question directly. Your output must be a JSON array, where each object represents a keyframe you wish to generate.
Each object must contain the following two keys:
1. `reference_image_id`: An integer representing the ID of a frame already provided to you that you wish to use as a generation reference. This ID must be one of the IDs provided by the user.
2. `prompt`: A detailed text description to tell the image generation model what kind of scene to draw.

For example, if the question is "Where did the man in the red shirt eventually go?", you might generate the following JSON:
```json
[
{
"reference_image_id": 120,
"prompt": "A man in a red shirt is walking towards an open door, with a background similar to the reference image."
},
{
"reference_image_id": 120,
"prompt": "A man in a red shirt has already walked out the door, and the door is closing, with a background similar to the reference image."
}
]
```
Your output must strictly adhere to this JSON format.
"""
|
|
|
|
|
|
|
|
# System prompt for Step 3: the VLM sees the retrieved keyframes plus the question
# and must reason step by step, then finish with a fenced JSON block holding the
# chosen option under the key "answer". The text is sent verbatim to the model.
STEP_3_FINAL_ANSWER_PROMPT = """
You are an AI video question-answering assistant.
The user will provide you with a series of keyframes retrieved from a video and a question.

First, please provide a step-by-step reasoning process, analyzing these keyframes and deriving your conclusion.
After your reasoning, provide the final answer. The answer must be in a JSON code block, and the JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output format must be strictly as follows:
<Your step-by-step reasoning process>
```json
{"answer": "A"}
```
Do not include any other text after the JSON block.
"""
|
|
|
|
|
|
|
|
def parse_arguments():
    """Parse the command-line arguments of the video QA workflow.

    The endpoint and key flags accept both a hyphenated spelling
    (``--base-url`` / ``--api-key``), which matches every other flag, and the
    original underscore spelling (``--base_url`` / ``--api_key``), so existing
    launch scripts keep working.

    Returns:
        argparse.Namespace: Parsed options exposing ``target_model``,
        ``frames_path``, ``data_file``, ``embeddings_path``, ``output_path``,
        ``initial_frames_num``, ``max_retry_times``, ``pool_processes``,
        ``base_url`` and ``api_key``.
    """
    parser = argparse.ArgumentParser(
        description="Image Retrieval-based Video QA Workflow"
    )

    parser.add_argument(
        "--target-model", "-tm", type=str, required=True,
        help="VLM model for inference (e.g., gpt-4o)",
    )

    parser.add_argument(
        "--frames-path", "-fp", type=str, required=True,
        help="Root directory containing video frame folders",
    )
    parser.add_argument(
        "--data-file", "-df", type=str, required=True,
        help="JSON data file containing evaluation questions",
    )
    parser.add_argument(
        "--embeddings-path", "-ep", type=str, required=True,
        help="Directory containing pre-computed embeddings for all video frames",
    )
    parser.add_argument(
        "--output-path", "-op", type=str, default="./results_image_retrieval",
        help="Directory to store all outputs and generated images",
    )

    parser.add_argument(
        "--initial-frames-num", "-ifn", type=int, default=8,
        help="Number of initial uniformly sampled frames for Step 1",
    )

    parser.add_argument(
        "--max-retry-times", "-mr", type=int, default=10,
        help="Maximum number of retries for API calls",
    )
    parser.add_argument(
        "--pool-processes", "-pp", type=int, default=10,
        help="Number of parallel processes",
    )

    # ``dest`` is pinned so both spellings populate the same attribute.
    parser.add_argument(
        "--base-url", "--base_url", dest="base_url", type=str, required=True,
        help="API Endpoint URL for the VLM model",
    )
    parser.add_argument(
        "--api-key", "--api_key", dest="api_key", type=str, required=True,
        help="API Key for the VLM model",
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
def save_json_file(data, output_file):
    """Write ``data`` to ``output_file`` as indented UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is (``ensure_ascii=False``).

    Args:
        data: Any JSON-serialisable object.
        output_file: Destination path; may be a bare filename.
    """
    parent_dir = os.path.dirname(output_file)
    # dirname is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
|
|
|
|
|
|
|
|
def extract_json_from_response(response, is_list=False):
    """Extract a JSON object or array from a model response.

    Two strategies are tried in order:

    1. The first ```` ```json ... ``` ```` fenced block. Whatever it parses to
       is returned unchanged.
    2. If there is no fenced block, or it does not parse, the raw text is
       scanned for the first position where a JSON array (``is_list=True``)
       or a JSON object (``is_list=False``) decodes successfully. This covers
       models that answer with bare JSON or use an untagged fence.

    Args:
        response: Text returned by the model; may be ``None`` or empty.
        is_list: When the fallback scan runs, look for an array instead of an
            object.

    Returns:
        The decoded JSON value, or ``None`` when nothing could be decoded.
    """
    if not response:
        return None

    fenced = re.search(r"```json\s*([\{\[].*?[\]\}])\s*```", response, re.DOTALL)
    if fenced:
        json_str = fenced.group(1)
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            print(f"JSON parsing failed: {json_str}")

    # Fallback: try every candidate opening bracket until one decodes.
    opener = "[" if is_list else "{"
    decoder = json.JSONDecoder()
    start = response.find(opener)
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            start = response.find(opener, start + 1)
        else:
            return parsed
    return None
|
|
|
|
|
|
|
|
def calculate_metrics(results):
    """Summarise evaluation results into counts and an accuracy figure.

    Items carrying an ``"error"`` key are excluded from the sample counts and
    reported separately. Accuracy is ``correct_answers / answered_samples``,
    so items without a parsed answer do not lower it.

    The same set of keys is returned for every input, including an empty or
    all-error list.

    Args:
        results: Iterable of per-item result dicts.

    Returns:
        dict: ``total_samples``, ``answered_samples``, ``correct_answers``,
        ``error_samples`` and ``accuracy``.
    """
    results = list(results)
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)

    answered = sum(1 for r in valid_results if r.get("model_answer") is not None)
    correct = sum(1 for r in valid_results if r.get("is_correct"))
    accuracy = correct / answered if answered > 0 else 0.0

    return {
        "total_samples": total_samples,
        "answered_samples": answered,
        "correct_answers": correct,
        "error_samples": len(results) - total_samples,
        "accuracy": accuracy,
    }
|
|
|
|
|
|
|
|
def call_vlm_api(client, messages, model, item_id, max_retry_times, json_schema=None):
    """Send a chat-completion request, retrying on failure.

    Args:
        client: Object exposing ``chat.completions.create(**params)``.
        messages: Chat messages forwarded unchanged to the API.
        model: Model name forwarded to the API.
        item_id: Identifier used only in the error log line.
        max_retry_times: Maximum number of attempts. Values below 1 are
            treated as 1 so the API is always called at least once.
        json_schema: Optional schema; when truthy it is sent as
            ``response_format={"type": "json_object", "schema": json_schema}``.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        Exception: Whatever the last attempt raised, once attempts run out.
    """
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if json_schema:
        params["response_format"] = {"type": "json_object", "schema": json_schema}

    # A non-positive retry count used to skip the loop and return None silently.
    attempts = max(1, max_retry_times)
    for attempt in range(1, attempts + 1):
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message.content
        except Exception as e:
            print(f"API Error (item {item_id}): {e}. Retrying ({attempt}/{attempts})...")
            if attempt == attempts:
                # Bare raise keeps the original traceback intact.
                raise
            time.sleep(5)
|
|
|
|
|
|
|
|
def generate_image(reference_image_id, prompt, all_frame_paths, output_dir, generation_idx):
    """Generate a new frame from a reference frame and a text prompt.

    The reference frame is Base64-encoded into a data URI and sent together
    with ``prompt`` to ``client.images.generate`` on an ``Ark`` client. The
    returned URL is downloaded and saved as
    ``generated_frame_<generation_idx>_ref_<reference_image_id>.jpg`` inside
    ``output_dir``.

    Args:
        reference_image_id: Key into ``all_frame_paths`` for the reference frame.
        prompt: Text description passed to the image generation API.
        all_frame_paths: Mapping of frame id to image path.
        output_dir: Directory the downloaded image is written to.
        generation_idx: Index used in the output filename.

    Returns:
        The saved image path, or ``None`` when generation, download or saving
        fails (the error is printed with its traceback).

    Raises:
        ValueError: If the ``ARK_API_KEY`` environment variable is not set.
        FileNotFoundError: If the reference frame is unknown or missing on disk.
    """
    print(f"\n[Image Generation] Using Prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Environment variable ARK_API_KEY is not set.")

    client = Ark(base_url="https://ark.cn-beijing.volces.com/api/v3", api_key=ark_api_key)

    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        raise FileNotFoundError(f"Reference image ID {reference_image_id} not found.")

    try:
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",
            response_format="url",
            watermark=False,
        )
        image_url = response.data[0].url

        download = requests.get(image_url, timeout=60)
        # Without this check an HTTP error body (e.g. an expired-URL page)
        # would be written to disk and later embedded as if it were an image.
        download.raise_for_status()
        image_content = download.content

        new_frame_filename = f"generated_frame_{generation_idx}_ref_{reference_image_id}.jpg"
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(image_content)

        print(f"[Image Generation Success] Image saved to: {new_frame_path}")
        return new_frame_path
    except Exception as e:
        # Best-effort by design: the caller falls back to the reference frame.
        print(f"Image generation or download failed: {e}")
        traceback.print_exc()
        return None
|
|
|
|
|
def retrieve_frames_by_image_embedding(
    image_path, video_embeddings_data, request_queue, results_dict, k, timeout=600.0
):
    """Return the paths of the ``k`` video frames most similar to an image.

    The image path is posted to ``request_queue`` under a fresh UUID and the
    function polls ``results_dict`` until the embedding for that UUID appears.
    Frames are ranked by cosine similarity against the pre-computed frame
    embeddings.

    Args:
        image_path: Path of the query image.
        video_embeddings_data: Dict with ``"filenames"`` (list of frame names
            or paths) and ``"embeddings"`` (tensor).
            NOTE(review): assumes embeddings are (num_frames, dim) and the
            query embedding is (1, dim) — confirm against the producer.
        request_queue: Queue the embedding server reads ``(id, path)`` from.
        results_dict: Mapping the embedding server writes results into.
        k: Number of frames to return (capped at the number of frames).
        timeout: Seconds to wait for the embedding; ``None`` waits forever.

    Returns:
        list[str]: Frame paths, most similar first; empty when the video has
        no frames.

    Raises:
        TimeoutError: If no embedding arrives within ``timeout`` seconds.
        RuntimeError: If the server posts ``None`` for this request.
    """
    frame_filenames = video_embeddings_data["filenames"]
    if not frame_filenames:
        return []

    device = "cuda" if torch.cuda.is_available() else "cpu"
    frame_embeddings = video_embeddings_data["embeddings"].to(device)

    request_id = str(uuid.uuid4())
    request_queue.put((request_id, image_path))

    # Bounded wait: an unanswered request must not hang the worker forever.
    deadline = None if timeout is None else time.monotonic() + timeout
    while request_id not in results_dict:
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError(
                f"No embedding received for {image_path} within {timeout} seconds."
            )
        time.sleep(0.05)

    query_embedding = results_dict.pop(request_id)
    if query_embedding is None:
        raise RuntimeError(f"Embedding server failed to embed {image_path}.")
    query_embedding = query_embedding.to(device)

    with torch.no_grad():
        similarities = cosine_similarity(query_embedding, frame_embeddings)
        top_k = min(k, len(frame_filenames))
        top_k_indices = torch.topk(similarities, k=top_k, dim=-1).indices.cpu().tolist()

    video_frame_dir = os.path.dirname(frame_filenames[0])
    # Join with the basename only: joining a relative path that already
    # contains the directory would duplicate it ("a/b" + "a/b/f.jpg").
    return [
        os.path.join(video_frame_dir, os.path.basename(frame_filenames[i]))
        for i in top_k_indices
    ]
|
|
|
|
|
def embedding_server_process(model_id, device, request_queue, results_dict):
    """Serve image-embedding requests from worker processes.

    Loads the model and processor once, then loops: read ``(request_id,
    image_path)`` from ``request_queue``, embed the image with
    ``model.get_image_features`` and store the CPU tensor in
    ``results_dict[request_id]``. A request whose ``image_path`` is the string
    ``"STOP"`` ends the loop.

    If a request fails, ``None`` is stored under its id so the worker polling
    for that id is released instead of waiting forever.

    Args:
        model_id: Identifier or path passed to ``from_pretrained``.
        device: Device string the model and inputs are moved to.
        request_queue: Queue of ``(request_id, image_path)`` tuples.
        results_dict: Shared mapping the results are written into.
    """
    print(f"Embedding server started (PID: {os.getpid()})...")
    model = AutoModel.from_pretrained(model_id).to(device).eval()
    processor = AutoProcessor.from_pretrained(model_id)
    print("SigLIP model loaded in the embedding server.")

    while True:
        request_id = None
        try:
            request_id, image_path = request_queue.get()
            if image_path == "STOP":
                print("Embedding server received stop signal, shutting down.")
                break

            with torch.no_grad():
                # Context manager closes the file handle once pixels are copied.
                with Image.open(image_path) as source_image:
                    image = source_image.convert("RGB")
                inputs = processor(images=[image], return_tensors="pt").to(device)
                image_features = model.get_image_features(**inputs)
            results_dict[request_id] = image_features.cpu()

        except (EOFError, ConnectionError):
            # The manager that owns the queue is gone; retrying would spin forever.
            print("Embedding server lost its connection to the manager, shutting down.")
            break
        except Exception as e:
            print(f"Error in embedding server: {e}")
            traceback.print_exc()
            if request_id is not None:
                # Publish a failure marker so the requesting worker stops waiting.
                results_dict[request_id] = None
|
|
|
|
|
|
|
|
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as Base64 text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
|
|
|
|
|
|
|
def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample frames from a directory and encode them for the VLM.

    Only files named like ``frame_<number>.jpg`` are considered; other
    ``.jpg`` files are ignored. Frames are ordered by their numeric id and
    ``num_frames`` evenly spaced ones are picked. When the directory holds
    fewer frames than requested, each frame is used once.

    Args:
        frames_dir: Directory containing the frame images.
        num_frames: Number of frames to sample.

    Returns:
        tuple: ``(encoded_frames, frame_path_map)``. ``encoded_frames`` is a
        flat list alternating a text item naming the frame id and an
        ``image_url`` item with the Base64 data URI. ``frame_path_map`` maps
        frame id to file path. Both are empty when nothing can be sampled.
    """
    if not os.path.isdir(frames_dir):
        return [], {}

    frame_name_pattern = re.compile(r"frame_(\d+)\.jpg")
    numbered_frames = []
    for name in os.listdir(frames_dir):
        if not name.endswith(".jpg"):
            continue
        # Skip stray .jpg files; they used to crash the numeric sort key.
        match = frame_name_pattern.search(name)
        if match:
            numbered_frames.append((int(match.group(1)), name))
    if not numbered_frames:
        return [], {}
    numbered_frames.sort()

    total = len(numbered_frames)
    raw_indices = [int(i * total / num_frames) for i in range(num_frames)]
    # Requesting more frames than exist repeats indices; keep each frame once.
    indices = list(dict.fromkeys(raw_indices))

    frame_path_map, encoded_frames = {}, []
    for index in indices:
        frame_id, name = numbered_frames[index]
        path = os.path.join(frames_dir, name)
        encoded_frames.extend([
            {"type": "text", "text": f"This is Frame ID: {frame_id}"},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}},
        ])
        frame_path_map[frame_id] = path
    return encoded_frames, frame_path_map
|
|
|
|
|
|
|
|
def _build_vlm_client(args):
    """Choose the API client class from substrings of ``args.base_url``."""
    if "ark" in args.base_url:
        return Ark(base_url=args.base_url, api_key=args.api_key)
    if "aliyun" in args.base_url or "127.0.0.1" in args.base_url:
        return OpenAI(api_key=args.api_key, base_url=args.base_url)
    return AzureOpenAI(api_version="2023-05-15", api_key=args.api_key, azure_endpoint=args.base_url)


def _sanitize_generation_requests(planned_requests, valid_ids):
    """Drop malformed planning requests and snap unknown reference ids to the nearest valid id."""
    cleaned = []
    for req in planned_requests:
        prompt = req.get("prompt") if isinstance(req, dict) else None
        if not isinstance(prompt, str) or not prompt.strip():
            print(f"Warning: Skipping malformed image generation request: {req}")
            continue

        original_id = req.get("reference_image_id")
        if original_id not in valid_ids:
            try:
                numeric_id = int(original_id)
            except (TypeError, ValueError):
                # Missing or non-numeric id: there is no distance to measure.
                closest_id = valid_ids[0]
            else:
                closest_id = min(valid_ids, key=lambda valid_id: abs(valid_id - numeric_id))
            print(f"Warning: Model generated a non-existent reference_image_id: {original_id}. Substituting with the closest valid ID: {closest_id}.")
            req["reference_image_id"] = closest_id
        cleaned.append(req)
    return cleaned


def _frame_sort_key(path):
    """Sort key ordering frame paths by numeric frame id; other names sort last by path."""
    match = re.search(r"frame_(\d+)\.jpg", os.path.basename(path))
    if match:
        return (0, int(match.group(1)), path)
    return (1, 0, path)


def run_workflow_for_item(
    data_item, args, request_queue, results_dict
):
    """Run the three-step workflow for one data item.

    Step 1 sends uniformly sampled frames and the question to the VLM, which
    plans image generation requests. Step 2 generates an image per request
    (falling back to the reference frame when generation fails) and retrieves
    the most similar video frames for each. Step 3 sends the retrieved frames,
    ordered by frame number, back to the VLM for the final answer.

    Args:
        data_item: Dict with at least ``"key"``, ``"question"`` and ``"answer"``.
        args: Parsed command-line options.
        request_queue: Queue used to reach the embedding server.
        results_dict: Shared mapping the embedding server answers through.

    Returns:
        dict: ``data_item`` extended with ``workflow_steps``, ``model_answer``
        and ``is_correct``.

    Raises:
        FileNotFoundError: If the initial frames or the embedding file are missing.
        ValueError: If planning yields no usable request or nothing is retrieved.
    """
    item_key = data_item["key"]
    print(f"\n--- Starting processing for video: {item_key} ---")

    generated_images_dir = os.path.join(args.output_path, "generated_images", item_key)
    os.makedirs(generated_images_dir, exist_ok=True)

    client = _build_vlm_client(args)

    # ---- Step 1: plan which keyframes to generate ----
    print(f"[{item_key}] Step 1: Uniformly sampling and generating keyframe creation requests...")
    video_frames_path = os.path.join(args.frames_path, item_key)
    initial_frames_encoded, initial_frame_paths = uniformly_sample_frames_and_encode(
        video_frames_path, args.initial_frames_num
    )
    if not initial_frames_encoded:
        raise FileNotFoundError(f"Initial frames not found for video {item_key}.")

    planning_messages = [
        {"role": "system", "content": STEP_1_PLANNING_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are the initial sample frames and the question:"},
            *initial_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    raw_planning_response = call_vlm_api(client, planning_messages, args.target_model, item_key, args.max_retry_times)
    image_generation_requests = extract_json_from_response(raw_planning_response, is_list=True)

    if not image_generation_requests or not isinstance(image_generation_requests, list):
        raise ValueError(f"Step 1 failed to generate valid JSON-formatted image generation requests. Response: {raw_planning_response}")

    valid_ids = list(initial_frame_paths.keys())
    if not valid_ids:
        raise ValueError(f"No valid initial frame IDs found for video {item_key}.")

    image_generation_requests = _sanitize_generation_requests(image_generation_requests, valid_ids)
    if not image_generation_requests:
        raise ValueError(f"Step 1 failed to generate valid JSON-formatted image generation requests. Response: {raw_planning_response}")

    print(f"[{item_key}] Successfully generated {len(image_generation_requests)} keyframe generation requests.")

    # ---- Step 2: generate images and retrieve similar frames ----
    print(f"[{item_key}] Step 2: Generating images and retrieving similar frames...")
    all_retrieved_frame_paths = set()
    generated_image_paths = []
    video_embedding_file = os.path.join(args.embeddings_path, f"{item_key}.pt")
    if not os.path.exists(video_embedding_file):
        raise FileNotFoundError(f"Embedding file for video {item_key} not found: {video_embedding_file}")
    # NOTE(review): torch.load unpickles the file; only load embeddings from a trusted source.
    video_embeddings_data = torch.load(video_embedding_file, map_location="cpu")

    # Re-root the stored filenames onto this run's frames directory.
    video_frame_dir_for_embeddings = os.path.join(args.frames_path, item_key)
    video_embeddings_data['filenames'] = [os.path.join(video_frame_dir_for_embeddings, os.path.basename(f)) for f in video_embeddings_data['filenames']]

    for i, req in enumerate(image_generation_requests):
        generated_path = generate_image(
            reference_image_id=req["reference_image_id"],
            prompt=req["prompt"],
            all_frame_paths=initial_frame_paths,
            output_dir=generated_images_dir,
            generation_idx=i + 1,
        )

        if generated_path:
            generated_image_paths.append(generated_path)
            path_for_retrieval = generated_path
        else:
            print(f"Warning: Generation failed for image {i+1}. Using its reference image (ID: {req['reference_image_id']}) for retrieval instead.")
            path_for_retrieval = initial_frame_paths.get(req["reference_image_id"])

        if not path_for_retrieval:
            print(f"Error: Could not find a path for retrieval for request {i+1}. Skipping.")
            continue

        retrieved_paths = retrieve_frames_by_image_embedding(
            path_for_retrieval, video_embeddings_data, request_queue, results_dict, k=TOP_K_FRAMES
        )
        all_retrieved_frame_paths.update(retrieved_paths)
        print(f"[{item_key}] Retrieval {i+1}/{len(image_generation_requests)} complete, found {len(retrieved_paths)} frames.")

    if not all_retrieved_frame_paths:
        raise ValueError(f"Failed to retrieve any frames for video {item_key}.")

    print(f"[{item_key}] Step 2 complete. Retrieved a total of {len(all_retrieved_frame_paths)} unique keyframes.")

    # ---- Step 3: final reasoning over the retrieved keyframes ----
    print(f"[{item_key}] Step 3: Consolidating keyframes for final reasoning...")
    # Numeric ordering keeps frames chronological; a plain string sort would
    # put frame_10 before frame_2.
    ordered_frame_paths = sorted(all_retrieved_frame_paths, key=_frame_sort_key)
    final_frames_encoded = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_image(path)}"}}
        for path in ordered_frame_paths
    ]

    final_messages = [
        {"role": "system", "content": STEP_3_FINAL_ANSWER_PROMPT},
        {"role": "user", "content": [
            {"type": "text", "text": "Here are all the keyframes retrieved for you. Please answer the question based on them."},
            *final_frames_encoded,
            {"type": "text", "text": f"Question: {data_item['question']}"}
        ]}
    ]

    final_response_text = call_vlm_api(client, final_messages, args.target_model, item_key, args.max_retry_times)

    parsed_answer = extract_json_from_response(final_response_text)
    # Tolerate a non-dict payload or a non-string answer instead of crashing.
    if isinstance(parsed_answer, dict) and parsed_answer:
        raw_answer = parsed_answer.get("answer", "")
        if raw_answer is None:
            model_answer = None
        else:
            model_answer = str(raw_answer).strip().upper()
    else:
        model_answer = None
    is_correct = (model_answer == data_item["answer"].strip().upper()) if model_answer else False

    result = {
        **data_item,
        "workflow_steps": {
            "step1_planning_requests": image_generation_requests,
            "step2_generated_images": generated_image_paths,
            "step2_retrieved_frame_paths": ordered_frame_paths,
            "step3_final_reasoning_and_answer": final_response_text,
        },
        "model_answer": model_answer,
        "is_correct": is_correct,
    }
    return result
|
|
|
|
|
|
|
|
def process_single_data_wrapper(data_item, args, request_queue, results_dict):
    """Run the workflow for one item and turn any exception into an error record.

    Returns the workflow result on success. On failure the error is printed
    with its traceback and a dict with ``key``, ``uid``, ``error`` and
    ``traceback`` is returned instead of raising.
    """
    try:
        outcome = run_workflow_for_item(data_item, args, request_queue, results_dict)
    except Exception as exc:
        print(f"\nA critical error occurred while processing video {data_item['key']}: {exc}")
        traceback.print_exc()
        failure_record = dict(
            key=data_item['key'],
            uid=data_item.get('uid'),
            error=str(exc),
            traceback=traceback.format_exc(),
        )
        return failure_record
    else:
        return outcome
|
|
|
|
|
def main():
    """Orchestrate the full evaluation run.

    Parses arguments, starts the embedding server process, fans the data items
    out over a process pool, checkpoints results to disk every ten items, and
    finally writes the results and metrics files. The embedding server is
    stopped and joined even when processing fails part-way.
    """
    args = parse_arguments()
    print("--- Image Retrieval-based Video QA Workflow Starting ---")
    print(f"Evaluating Model: {args.target_model}, Dataset: {args.data_file}")

    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError:
        pass

    os.makedirs(args.output_path, exist_ok=True)

    model_safe_name = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]
    output_prefix = f"{model_safe_name}_{data_filename_base}_image_retrieval_{args.initial_frames_num}frames"

    results_file = os.path.join(args.output_path, f"{output_prefix}_results.json")
    metrics_file = os.path.join(args.output_path, f"{output_prefix}_metrics.json")

    # NOTE(review): load_test_data is not defined in the visible part of this
    # file — confirm it exists, otherwise this line raises NameError.
    test_data = load_test_data(args.data_file)
    all_results = []

    with multiprocessing.Manager() as manager:
        request_queue = manager.Queue()
        results_dict = manager.dict()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        embedding_server = multiprocessing.Process(
            target=embedding_server_process,
            args=(SIGLIP_MODEL_ID, device, request_queue, results_dict),
        )
        embedding_server.start()

        try:
            # Give the server time to load the model before workers start.
            time.sleep(15)
            if not embedding_server.is_alive():
                raise RuntimeError("Embedding server exited during startup; see its output above.")

            with concurrent.futures.ProcessPoolExecutor(max_workers=args.pool_processes) as executor:
                func = partial(
                    process_single_data_wrapper,
                    args=args,
                    request_queue=request_queue,
                    results_dict=results_dict
                )

                results_iterator = executor.map(func, test_data)

                for result in tqdm(results_iterator, total=len(test_data), desc="Processing Videos"):
                    if result:
                        all_results.append(result)
                        # Periodic checkpoint so a crash loses at most ten items.
                        if len(all_results) % 10 == 0:
                            save_json_file(all_results, results_file)

            print("All tasks completed. Shutting down the embedding server...")
        finally:
            # Always release the server, including when the pool raised.
            if embedding_server.is_alive():
                request_queue.put((None, "STOP"))
            embedding_server.join(timeout=120)
            if embedding_server.is_alive():
                embedding_server.terminate()
                embedding_server.join()

    print("\n--- All Videos Processed ---")
    save_json_file(all_results, results_file)
    print(f"Detailed results saved to: {results_file}")

    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_file)
    print(f"Final evaluation metrics saved to: {metrics_file}")
    print(json.dumps(final_metrics, indent=4))
|
|
|
|
|
|
|
|
# Entry-point guard: run the workflow only when executed as a script. main()
# selects the "spawn" start method, under which child processes import this
# module again, so the guard also keeps them from re-running main().
if __name__ == "__main__":
    main()
|
|
|
|
|
|