import os
import json
import base64
import argparse
import time
import re
import traceback
from datetime import datetime
from functools import partial
import requests  # Import requests library to download images from URLs
from openai import AzureOpenAI, OpenAI
from volcenginesdkarkruntime import Ark
import concurrent.futures
from tqdm import tqdm

# 1. New Agent System Prompt
# Defines the agent's role and principles, guiding it to use the "imagination" tool when visual evidence is insufficient.
IMAGINE_AGENT_SYSTEM_PROMPT = """
You are an intelligent AI assistant specializing in answering video question-answering problems through reasoning and imagination.
Your task is to answer a multiple-choice question based on an initial, limited set of video frames.

You will receive a few uniformly sampled frames to get a basic understanding of the video.
These frames may not contain all the visual evidence needed to directly answer the question.

If the provided frame information is insufficient, you must use the `imagine_frame` tool to generate new, imagined frames to fill in the visual gaps and aid your reasoning.
You can call this tool multiple times to construct a sequence of imagined events.

Your strategy should be:
1. Analyze the initial frames and the user's question.
2. Form a hypothesis about the missing content.
3. If you need more visual information, call the `imagine_frame` tool. Provide a text `prompt` describing the scene you want to imagine, and select a `reference_image_id` from existing frames. The `reference_image_id` MUST be one of the IDs explicitly provided to you in the conversation history (e.g., "Frame ID: X" or "New Frame ID: Y"). Do not invent or assume frame IDs.
4. Analyze the newly generated frame in conjunction with the existing ones.
5. Continue this process of reasoning and imagination until you are confident in your answer. Please ensure you have found or created the relevant visual cues before answering the question.
6. Each tool call can only generate one frame.

IMPORTANT: Your text `prompt` for image generation must be safe and general. Avoid descriptions that could be interpreted as sensitive, harmful, or explicit to prevent generation failures.

After your reasoning, provide the final answer in a JSON code block. The JSON object must contain a key "answer" with a value of one of 'A', 'B', 'C', or 'D'.

Your output must strictly follow this format:
<Your step-by-step reasoning process here, including why you chose to imagine a certain frame>
```json
{"answer": "X"}
```
Do not include any other text after the JSON code block.
"""

# 2. New Tool Schema for imagine_frame
# Defines the interface, parameters, and description for the `imagine_frame` tool.
IMAGINE_FRAME_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "imagine_frame",
        "description": "When visual evidence is insufficient, generates a new image based on a text prompt and a reference image to help answer the question. Use it to imagine what might have happened between the provided frames.",
        "parameters": {
            "type": "object",
            "properties": {
                "reference_image_id": {
                    "type": "integer",
                    "description": "The ID of an existing frame to use as a style and content reference. It can be one of the original frames or a previously generated one.",
                },
                "prompt": {
                    "type": "string",
                    "description": "A detailed text description of the frame you want to imagine and generate.",
                },
            },
            "required": ["reference_image_id", "prompt"],
        },
    },
}


# 3. Implementation of the `imagine_frame` tool
def imagine_frame(
    reference_image_id: int,
    prompt: str,
    all_frame_paths: dict,
    output_dir: str,
    generation_count: int,
):
    """
    Tool implementation: Calls an image generation model to create a new frame.

    Args:
        reference_image_id (int): The ID of the reference frame.
        prompt (str): The text prompt for image generation.
        all_frame_paths (dict): A dictionary containing IDs and paths of all currently available frames (original + generated).
        output_dir (str): The directory to save the generated image.
        generation_count (int): The current generation count, used for naming the file.

    Returns:
        str or None: The path of the newly generated image on success, otherwise None.
    """
    print(f"\n[Tool Call] Imagining new frame with prompt: '{prompt}'")
    ark_api_key = os.environ.get("ARK_API_KEY")
    if not ark_api_key:
        raise ValueError("Error: Environment variable ARK_API_KEY is not set.")

    client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=ark_api_key,
    )

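    ref_image_path = all_frame_paths.get(reference_image_id)
    if not ref_image_path or not os.path.exists(ref_image_path):
        # An unknown or stale ID is a recoverable model mistake: return None (as the
        # docstring promises) so the caller can report the failure back to the agent.
        print(f"[Tool Error] Reference image ID not found: {reference_image_id}")
        return None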

    try:
        # Encode the reference image to a Base64 Data URI
        ref_image_b64 = encode_image(ref_image_path)
        ref_image_data_uri = f"data:image/jpeg;base64,{ref_image_b64}"

        images_response = client.images.generate(
            model="doubao-seedream-4-0-250828",
            prompt=prompt,
            image=ref_image_data_uri,
            size="1024x1024",  # Can be adjusted as needed, e.g., "2K"
            response_format="url",
            watermark=False,
        )

        image_url = images_response.data[0].url

        # Download the image from the URL; the timeout keeps a stalled download
        # from hanging the worker process indefinitely
        response = requests.get(image_url, timeout=60)
        response.raise_for_status()

        # Save the image to the specified directory
        new_frame_filename = (
            f"generated_frame_{generation_count}_ref_{reference_image_id}.jpg"
        )
        new_frame_path = os.path.join(output_dir, new_frame_filename)

        with open(new_frame_path, "wb") as f:
            f.write(response.content)

        print(f"[Tool Success] Generated frame saved to: {new_frame_path}")
        return new_frame_path

    except Exception as e:
        print(f"An error occurred during image generation or download: {e}")
        traceback.print_exc()
        return None


def parse_arguments():
    """Parse command-line arguments"""
    parser = argparse.ArgumentParser(
        description="Video QA Evaluation Framework with Imagine-and-Reason Agent"
    )
    parser.add_argument(
        "--target-model",
        "-tm",
        type=str,
        required=True,
        help="The model to be evaluated (e.g., gpt-4o)",
    )
    parser.add_argument(
        "--frames-path",
        "-fp",
        type=str,
        required=True,
        help="Absolute path to the root directory containing video frames.",
    )
    parser.add_argument(
        "--output-path",
        "-op",
        type=str,
        default="./generated_outputs",
        help="Path to store generated images and results.",
    )
    parser.add_argument(
        "--data-file",
        "-df",
        type=str,
        required=True,
        help="Absolute path to the evaluation dataset JSON file.",
    )
    parser.add_argument(
        "--initial-frames-num",
        "-ifn",
        type=int,
        default=8,
        help="Number of initial uniformly sampled frames.",
    )
    parser.add_argument(
        "--max-retry-times",
        "-mr",
        type=int,
        default=10,
        help="Maximum number of retries for failed API calls.",
    )
    parser.add_argument(
        "--pool-processes",
        "-pp",
        type=int,
        default=10,
        help="Number of parallel processes.",
    )
    parser.add_argument(
        "--base_url",
        type=str,
        required=True,
        help="API Endpoint URL for the target model service.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        required=True,
        help="API Key for the target model service.",
    )
    return parser.parse_args()


def save_json_file(data, output_file):
    """Save data to a JSON file"""
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False)


def extract_json_from_response(response):
    """Extract JSON answer from the model's text response"""
    if not response:
        return None
    match = re.search(r"```json\s*(\{.*?\})\s*```", response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(1))
        except (json.JSONDecodeError, IndexError):
            return None
    return None


def calculate_metrics(results):
    """Calculate various metrics from the evaluation results"""
    valid_results = [r for r in results if "error" not in r]
    total_samples = len(valid_results)
    if total_samples == 0:
        return {
            "total_samples": 0,
            "answered_samples": 0,
            "correct_answers": 0,
            "accuracy": 0.0,
        }
    answered_samples = sum(
        1 for x in valid_results if x.get("model_answer") is not None
    )
    correct_answers = sum(1 for x in valid_results if x.get("is_correct"))
    accuracy = correct_answers / answered_samples if answered_samples > 0 else 0.0
    return {
        "total_samples": total_samples,
        "answered_samples": answered_samples,
        "correct_answers": correct_answers,
        "accuracy": accuracy,
    }


def call_single_model(client, messages, model, item_id, max_retry_times, tools=None):
    """A single model API call with retry logic"""
    params = {"model": model, "messages": messages, "max_tokens": 4096}
    if tools:
        params["tools"] = tools
        params["tool_choice"] = "auto"

    retry_times = 0
    while retry_times < max_retry_times:
        try:
            completion = client.chat.completions.create(**params)
            return completion.choices[0].message
        except Exception as e:
            retry_times += 1
            print(
                f"API call error (Item {item_id}): {str(e)}. Retrying ({retry_times}/{max_retry_times})..."
            )
            if retry_times == max_retry_times:
                raise  # Re-raise the original exception with its traceback intact
            time.sleep(5)


def uniformly_sample_frames_and_encode(frames_dir, num_frames):
    """Uniformly sample a specified number of frames from a directory and encode them"""
    if not os.path.isdir(frames_dir):
        return [], {}

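    # Keep only .jpg files whose names contain frame_<number>.jpg, so the numeric
    # sort key cannot fail on stray images in the directory
    frame_pattern = re.compile(r"frame_(\d+)\.jpg")
    frame_files = sorted(
        [
            f
            for f in os.listdir(frames_dir)
            if f.endswith(".jpg") and frame_pattern.search(f)
        ],
        key=lambda x: int(frame_pattern.search(x).group(1)),
    )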

    total_frames = len(frame_files)
    if total_frames == 0:
        return [], {}

    if total_frames > num_frames:
        indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
        sampled_files = [frame_files[i] for i in indices]
    else:
        sampled_files = frame_files

    frame_path_map = {}
    encoded_frames = []
    for f in sampled_files:
        path = os.path.join(frames_dir, f)
        frame_id = int(re.search(r"frame_(\d+)\.jpg", f).group(1))
        b64_image = encode_image(path)
        # Send frame ID and image content as a pair
        encoded_frames.append({"type": "text", "text": f"This is Frame ID: {frame_id}"})
        encoded_frames.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
            }
        )
        frame_path_map[frame_id] = path

    return encoded_frames, frame_path_map


def evaluate_single_item_agentic_imagination(
    data_item,
    initial_frames,
    initial_frame_paths,
    generated_images_dir,
    target_model,
    api_key,
    base_url,
    max_retry_times,
):
    """
    Core logic for evaluating a single data item using the Imagine-and-Reason Agent.
    """
    # 4. New Agent Loop
    if "ark" in base_url:
        client = Ark(base_url=base_url, api_key=api_key)
    elif "aliyun" in base_url or "127.0.0.1" in base_url:
        client = OpenAI(api_key=api_key, base_url=base_url)
    else:
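        # The `tools` parameter is not accepted by the 2023-05-15 API version,
        # so use a later GA version that supports tool calling
        client = AzureOpenAI(
            api_version="2024-02-01", api_key=api_key, azure_endpoint=base_url
        )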

    tools = [IMAGINE_FRAME_TOOL_SCHEMA]

    # Store paths of all available frames (initial + generated) in a dictionary for reference
    available_frame_paths = initial_frame_paths.copy()

    initial_prompt_content = [
        {
            "type": "text",
            "text": "Here are the initial sampled video frames provided to you:",
        },
        *initial_frames,
        {
            "type": "text",
            "text": f"Please answer the following question:\n{data_item['question']}",
        },
    ]

    messages = [
        {"role": "system", "content": IMAGINE_AGENT_SYSTEM_PROMPT},
        {"role": "user", "content": initial_prompt_content},
    ]

    response_content = None
    # Limit the number of times the agent can imagine to prevent infinite loops
    max_tool_calls = 5
    generation_count = 0

    for i in range(max_tool_calls):
        response_message = call_single_model(
            client,
            messages,
            target_model,
            data_item["key"],
            max_retry_times,
            tools=tools,
        )
        if response_message is None:
            return None

        messages.append(response_message.model_dump(exclude_none=True))

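        if response_message.tool_calls:
            # Only the first tool call is executed per turn (one frame per call)
            tool_call = response_message.tool_calls[0]
            function_name = tool_call.function.name
            new_frame_path = None
            error_message = "Failed to generate image."

            if function_name == "imagine_frame":
                try:
                    function_args = json.loads(tool_call.function.arguments)
                    reference_image_id = int(function_args["reference_image_id"])
                    prompt = str(function_args["prompt"])
                except (json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
                    # Malformed arguments are reported to the agent instead of aborting the item
                    error_message = f"Invalid tool arguments: {e}"
                else:
                    generation_count += 1
                    new_frame_path = imagine_frame(
                        reference_image_id=reference_image_id,
                        prompt=prompt,
                        all_frame_paths=available_frame_paths,
                        output_dir=generated_images_dir,
                        generation_count=generation_count,
                    )
            else:
                error_message = f"Unknown tool: {function_name}"

            if new_frame_path:
                # Create a unique ID for the newly generated frame
                new_frame_id = max(available_frame_paths, default=0) + 1
                available_frame_paths[new_frame_id] = new_frame_path
                tool_result = {"status": "success", "new_frame_id": new_frame_id}
            else:  # Tool execution failed
                tool_result = {"status": "error", "message": error_message}

            messages.append(
                {
                    "tool_call_id": tool_call.id,
                    "role": "tool",
                    "name": function_name,
                    "content": json.dumps(tool_result),
                }
            )
            # Every tool call in the assistant message needs a matching tool
            # response, so any extra calls are answered with an error
            for extra_call in response_message.tool_calls[1:]:
                messages.append(
                    {
                        "tool_call_id": extra_call.id,
                        "role": "tool",
                        "name": extra_call.function.name,
                        "content": json.dumps(
                            {
                                "status": "error",
                                "message": "Only one tool call is processed per turn.",
                            }
                        ),
                    }
                )

            if new_frame_path:
                b64_image = encode_image(new_frame_path)
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"Here is the frame you requested to imagine (New Frame ID: {new_frame_id}). Please use it to continue your reasoning.",
                            },
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{b64_image}"},
                            },
                        ],
                    }
                )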
        else:  # No tool call means the model is ready to give a final answer
            response_content = response_message.content
            break

    # If the max number of calls is reached without an answer, force a final response
    if response_content is None and response_message:
        final_prompt = "You have reached the maximum number of tool calls. Please provide a final answer in the specified JSON format based on the information you have gathered so far."
        messages.append({"role": "user", "content": final_prompt})
        final_response_message = call_single_model(
            client, messages, target_model, data_item["key"], max_retry_times
        )
        if final_response_message:
            messages.append(final_response_message.model_dump(exclude_none=True))
            response_content = final_response_message.content

    is_correct = False
    model_answer_cleaned = None
    parsed_json = extract_json_from_response(response_content)
    if parsed_json and "answer" in parsed_json:
        model_answer_cleaned = str(parsed_json["answer"]).strip().upper()
        gold_answer = data_item["answer"].strip().upper()
        if model_answer_cleaned == gold_answer:
            is_correct = True

    return {
        **data_item,
        "agent_conversation": messages,
        "model_reasoning_and_answer": response_content,
        "model_answer": model_answer_cleaned,
        "is_correct": is_correct,
        "generated_images_path": generated_images_dir,  # 5. Store the path to intermediate generated images
    }


def encode_image(image_path):
    """Encode an image file to a Base64 string"""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def process_single_data(data_item, args):
    """Worker function to process a single data item in parallel"""
    item_key = data_item["key"]
    try:
        # Create a separate subfolder for each video's generated images
        generated_images_dir = os.path.join(
            args.output_path, "generated_images", item_key
        )
        os.makedirs(generated_images_dir, exist_ok=True)

        specific_frames_path = os.path.join(args.frames_path, item_key)
        initial_frames, initial_frame_paths = uniformly_sample_frames_and_encode(
            specific_frames_path, args.initial_frames_num
        )

        if not initial_frames:
            raise FileNotFoundError(f"Initial frames not found for item '{item_key}'")

        result = evaluate_single_item_agentic_imagination(
            data_item,
            initial_frames,
            initial_frame_paths,
            generated_images_dir,
            args.target_model,
            args.api_key,
            args.base_url,
            args.max_retry_times,
        )
        return result

    except Exception as e:
        print(f"\nA critical error occurred while processing item {item_key}: {str(e)}")
        traceback.print_exc()
        return {
            "key": item_key,
            "uid": data_item.get("uid"),
            "error": str(e),
            "traceback": traceback.format_exc(),
        }


def load_test_data(json_file):
    """Load test data from a JSON file"""
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            return json.load(f)
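    except FileNotFoundError:
        print(f"Error: Data file not found: {json_file}")
        # SystemExit is used directly because the `exit` helper is only injected
        # by the `site` module and is not guaranteed to exist
        raise SystemExit(1)
    except json.JSONDecodeError:
        print(f"Error: JSON file is malformed: {json_file}")
        raise SystemExit(1)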


def main():
    """Main function to orchestrate the entire evaluation flow"""
    args = parse_arguments()

    print("--- Video QA Imagine-and-Reason Agent Framework ---")
    print(f"Evaluating Model: {args.target_model}")
    print(f"Output Path: {args.output_path}")
    print(f"Dataset: {args.data_file}")
    print("---------------------------------")

    # Create the main output directory
    os.makedirs(args.output_path, exist_ok=True)

    model_name_safe = args.target_model.replace("/", "_")
    data_filename_base = os.path.splitext(os.path.basename(args.data_file))[0]

    output_prefix = f"{model_name_safe}_{data_filename_base}_imagine_agent"
    results_output_file = os.path.join(
        args.output_path, f"{output_prefix}_results.json"
    )
    metrics_output_file = os.path.join(
        args.output_path, f"{output_prefix}_metrics.json"
    )
    error_log_file = os.path.join(args.output_path, f"{output_prefix}_errors.log")

    # The logic for resuming from a checkpoint can be added here, same as in the first script

    all_test_data = load_test_data(args.data_file)
    tasks_to_process = all_test_data

    all_results = []
    # Use ProcessPoolExecutor for parallel processing
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.pool_processes
    ) as executor:
        func = partial(process_single_data, args=args)
        results_iterator = executor.map(func, tasks_to_process)

        for result in tqdm(
            results_iterator, total=len(tasks_to_process), desc="Processing Videos"
        ):
            if result:
                if "error" in result:
                    with open(error_log_file, "a", encoding="utf-8") as f:
                        f.write(
                            f"Error on item {result.get('key', 'N/A')}:\n  Error: {result['error']}\n---\n"
                        )
                all_results.append(result)

                # Save results every 10 videos to prevent data loss from interruptions
                if len(all_results) % 10 == 0:
                    save_json_file(all_results, results_output_file)

    print("\n\nProcessing complete.")
    # Save the final complete results
    save_json_file(all_results, results_output_file)
    print(f"Detailed results saved to: {results_output_file}")

    # Calculate and save the final metrics
    final_metrics = calculate_metrics(all_results)
    save_json_file(final_metrics, metrics_output_file)
    print(f"\nEvaluation metrics saved to: {metrics_output_file}")
    print(json.dumps(final_metrics, indent=4))


if __name__ == "__main__":
    # Before running this script, please ensure you have set the environment variable in your terminal:
    # export ARK_API_KEY="YOUR_VOLCENGINE_ARK_API_KEY"
    main()