import torch
import os
from tqdm import tqdm
from typing import Dict, Optional, Sequence, List, Tuple
import transformers
import re

import openai
from PIL import Image
import base64
import io
try:
    import cv2
except ImportError:
    cv2 = None
    print("Warning: OpenCV is not installed, video frame extraction will not work.")
from TStar.utils import *
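
# The helpers used below (encode_image_to_base64, load_video_frames) are
# expected to come from TStar.utils. The fallbacks here are minimal sketches
# of what they are assumed to do, defined only if the wildcard import did not
# provide them; they are illustrative, not the library's own implementations.
if "encode_image_to_base64" not in globals():
    def encode_image_to_base64(image: Image.Image) -> str:
        """Encode a PIL image as a base64-encoded JPEG string."""
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

if "load_video_frames" not in globals():
    def load_video_frames(video_path: str, num_frames: int = 8) -> List[Image.Image]:
        """Sample `num_frames` roughly evenly spaced frames from a video via OpenCV."""
        if cv2 is None:
            raise ImportError("OpenCV is required to extract video frames.")
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = max(total // num_frames, 1)
        frames = []
        for idx in range(0, total, step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
            if len(frames) == num_frames:
                break
        cap.release()
        return frames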

class LlavaInterface:
    """
    示例:封装对 Llava 模型的推理调用。
    关键在于对外暴露统一的方法 inference(query, frames, **kwargs)。
    """
    def __init__(self, model_path: str, model_base: Optional[str] = None):
        # 这里是加载 Llava 模型等的逻辑
        # self.tokenizer, self.model = ...
        self.model_path = model_path
        self.model_base = model_base
        print(f"[LlavaInterface] model_path={model_path}, model_base={model_base}")

    def inference(
        self,
        query: str,
        frames: Optional[List[Image.Image]] = None,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        top_p: Optional[float] = None,
        num_beams: int = 1,
        max_tokens: int = 512,
        **kwargs
    ) -> str:
        """
        对外暴露统一的推理接口。

        query: 用户输入,可能包含文本+<image>标记
        frames: 对应的图像帧列表
        system_message: 系统提示
        其它参数根据需要自行添加
        """
        
        # 模拟推理逻辑,需要你自己实现
        print("[LlavaInterface] Inference called with query:", query)
        print("[LlavaInterface] frames count:", len(frames) if frames else 0)

        # 真实场景下,你会调用 Llava 模型进行推理
        return "Fake Response from LlavaInterface"
    
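# One possible implementation of LlavaInterface.inference, sketched with the
# Hugging Face `transformers` LLaVA classes (an assumption on our part; the
# original leaves model loading and inference unimplemented):
#
#   from transformers import AutoProcessor, LlavaForConditionalGeneration
#   processor = AutoProcessor.from_pretrained(self.model_path)
#   model = LlavaForConditionalGeneration.from_pretrained(self.model_path)
#   inputs = processor(text=query, images=frames, return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=max_tokens)
#   response = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
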
class GPT4Interface:
    def __init__(self, model: str = "gpt-4o", api_key: Optional[str] = None):
        """
        Initialize the GPT-4 API client.

        Falls back to the `OPENAI_API_KEY` environment variable when no
        API key is passed in.
        """
        self.api_key = api_key
        self.model_name = model
        if api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("No API key provided and environment variable OPENAI_API_KEY is not set.")
        openai.api_key = self.api_key

    def inference_text_only(self, query: str, system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """
        Perform inference using the GPT-4 API.

        Args:
            query (str): User's query or input.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The response generated by the GPT-4 model.
        """
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ]

        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"
    
    def inference_with_frames(self, query: str, frames: List[Image.Image], system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """
        Perform inference using the GPT-4 API with video frames as context.

        Args:
            query (str): User's query or input.
            frames (List[Image.Image]): List of PIL.Image objects to provide visual context.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The response generated by the GPT-4 model.
        """

        # Messages format
        inputs = [{"type": "text", "text": query}]

        # Encode frames as Base64 strings
        for i, frame in enumerate(frames):
            try:
                # Convert PIL Image to Base64 string
                frame_base64 = encode_image_to_base64(frame)
                visual_context = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{frame_base64}",
                            "detail": "low"
                            }
                    }
                # Adding visual context (images) to messages if supported by the model
                inputs.append(visual_context)

            except Exception as e:
                return f"Error encoding frame {i}: {str(e)}"

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]
        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"

    def inference_qa(self, question: str, options: str, frames: Optional[List[Image.Image]] = None, system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 500) -> str:
        """
        Perform inference for a multiple-choice question with optional visual frames as context.

        Args:
            question (str): The question to answer.
            options (str): Multiple-choice options formatted as a string.
            frames (List[Image.Image], optional): List of PIL.Image objects to provide additional visual context.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The selected option or answer.
        """
        # Construct query
        query = f"Question: {question}\nOptions: {options}\nAnswer with the letter corresponding to the best choice."

        # Messages format
        inputs = [{"type": "text", "text": query}]

        if frames:
            # Encode frames as Base64 strings
            for i, frame in enumerate(frames):
                try:
                    frame_base64 = encode_image_to_base64(frame)
                    visual_context = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{frame_base64}",
                            "detail": "low"
                        }
                    }
                    # Adding visual context (images) to messages if supported by the model
                    inputs.append(visual_context)

                except Exception as e:
                    return f"Error encoding frame {i}: {str(e)}"

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]

        try:
            response = openai.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"

    def inference_with_frames_all_in_one(self, query: str, frames: List[Image.Image], system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """
        Perform inference using the GPT-4 API with video frames as context.
        Args:
            query (str): User's query or input. image tag: <image>
            frames (List[Image.Image]): List of PIL.Image objects to provide visual context.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.
        Returns:
            str: The response generated by the GPT-4 model.
        """
        # Split query by <image>
        parts = query.split("<image>")
        inputs = []

        # Add text and images alternately to inputs
        for i, part in enumerate(parts):
            if part.strip():
                inputs.append({"type": "text", "text": part.strip()})
            if i < len(frames):  # Ensure we don't exceed the number of available frames
                try:
                    frame_base64 = encode_image_to_base64(frames[i])
                    visual_context = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{frame_base64}",
                            "detail": "low"
                        }
                    }
                    inputs.append(visual_context)
                except Exception as e:
                    return f"Error encoding frame {i}: {str(e)}"

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]

        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"

class TStarUniversalGrounder:
    """
    结合了原先 TStarGrounder 与 TStarGPTGrounder 的功能,
    可以通过 backend 参数切换到底层使用的是 LlavaInterface 还是 GPT4Interface。
    """
    def __init__(
        self,
        backend: str = "gpt4",
        model_name: str = "gpt-4o",
        model_path: Optional[str] = None,
        model_base: Optional[str] = None,
        gpt4_api_key: Optional[str] = None,
        num_frames: Optional[int] = 8,
    ):
        """
        backend: "llava" or "gpt4"
        model_path, model_base: path and base version of the LLaVA model
        model_name, gpt4_api_key: GPT-4 model name and API key
        """
        self.backend = backend.lower()
        self.num_frames = num_frames
        if self.backend == "llava":
            # Initialize LlavaInterface
            if not model_path:
                raise ValueError("Please provide model_path for LlavaInterface")
            self.VLM_model_interfance = LlavaInterface(model_path=model_path, model_base=model_base)
        elif self.backend == "gpt4":
            # Initialize GPT4Interface
            self.VLM_model_interfance = GPT4Interface(model=model_name, api_key=gpt4_api_key)
        else:
            raise ValueError("backend must be either 'llava' or 'gpt4'.")

    def inference_query_grounding(
        self,
        video_path: str,
        question: str,
        options: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 512
    ) -> Tuple[List[str], List[str]]:
        """
        Identify the target_objects that can ground the answer, along with
        cue_objects that may appear nearby and aid the search.
        """

        frames = load_video_frames(video_path=video_path, num_frames=self.num_frames)
        # Build the prompt
        system_prompt = (
            "Here is a video:\n"
            + "\n".join(["<image>"] * len(frames))  
            + "\nHere is a question about the video:\n"
            f"Question: {question}\n"
        )
        if options:
            system_prompt += f"Options: {options}\n"
        system_prompt += (
            "\nWhen answering this question about the video:\n"
            "1. What key objects to locate the answer?\n"
            "   - List potential key objects (short sentences, separated by commas).\n"
            "2. What cue objects might be near the key objects and might appear in the scenes?\n"
            "   - List potential cue objects (short sentences, separated by commas).\n\n"
            "Please provide your answer in two lines, directly listing the key and cue objects, separated by commas."
        )

        # Route through the shared VLM interface (could be abstracted further)
        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # Parse the response according to the expected two-line format
        lines = response.split("\n")
        if len(lines) < 2:
            raise ValueError(f"Unexpected response format from inference_query_grounding() --> {response}.")

        target_objects = [self.check_objects_str(obj) for obj in lines[0].split(",") if obj.strip()]
        cue_objects = [self.check_objects_str(obj) for obj in lines[1].split(",") if obj.strip()]
        
        return target_objects, cue_objects

    def check_objects_str(self, obj: str) -> str:
        obj = obj.lower()  # lowercase
        obj = obj.strip().replace("1. ", "")
        obj = obj.strip().replace("2. ", "")
        obj = obj.strip().replace(".", "")
        obj = obj.strip().replace("key objects: ", "")
        obj = obj.strip().replace("cue objects: ", "")
        obj = obj.strip().replace(": ", "")

        return obj
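
    # Example (hypothetical input): check_objects_str("1. Key objects: red cup.")
    # returns "red cup" after the normalization above.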

    def inference_qa(
        self,
        frames: List[Image.Image],
        question: str,
        options: str,
        temperature: float = 0.2,
        max_tokens: int = 128
    ) -> str:
        """
        多选推理,返回最可能的选项(如 A、B、C、D)。
        """
        system_prompt = (
            "Select the best answer to the following multiple-choice question based on the video.\n"
            + "\n".join(["<image>"] * len(frames))  
            + f"\nQuestion: {question}\n"
            + f"Options: {options}\n\n"
            "Answer with the option’s letter from the given choices directly."
        )

        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.strip()
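
    # Hedged convenience helper, not part of the original API: extract the bare
    # option letter from a verbose reply such as "The answer is (B)."
    @staticmethod
    def extract_option_letter(response: str) -> Optional[str]:
        match = re.search(r"\b([A-D])\b", response.strip())
        return match.group(1) if match else None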

    def inference_openend_qa(
        self,
        frames: List[Image.Image],
        question: str,
        temperature: float = 0.2,
        max_tokens: int = 2048
    ) -> str:
        """
        Open-ended inference; returns a short free-form answer.
        """
        system_prompt = (
            "Answer the question briefly based on the video.\n"
            + "\n".join(["<image>"] * len(frames))
            + f"\nQuestion: {question}\n"
        )

        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.strip()

if __name__ == "__main__":
    """
    测试示例。
    """
    # 1) 使用 Llava 作为底层模型
    # print("=== Using Llava backend ===")
    # llava_grounder = TStarUniversalGrounder(
    #     backend="llava",
    #     model_path="/path/to/llava",
    #     model_base="v1.0"
    # )
    frames_fake = [ # 随记噪声更好
        Image.open("./output_image.jpg"),
        Image.open("/home/yejinhui/Projects/VisualSearch/output_image.jpg")
    ]
    # result_grounding_llava = llava_grounder.inference_query_grounding(
    #     frames=frames_fake,
    #     question="What objects are in the video?",
    # )
    # print("Llava Grounding Result:", result_grounding_llava)

    # 2) 使用 GPT-4 作为底层模型
    print("\n=== Using GPT-4 backend ===")
    gpt4_grounder = TStarUniversalGrounder(
        backend="gpt4",
        model_name="gpt-4o",
        gpt4_api_key=None
    )
    searchable_objects = gpt4_grounder.inference_query_grounding(
        frames=frames_fake,
        question="What objects are in the video?"
    )
    print("GPT-4 Grounding Result:", searchable_objects)

    # 3) Multiple-choice QA example
    question_mc = "How many cats can be seen?\n"
    options_mc = "A) 0\nB) 1\nC) 2\nD) 3\n"
    # answer_llava = llava_grounder.inference_qa(frames_fake, question_mc, options_mc)
    # print("Llava QA Answer:", answer_llava)

    answer_gpt4 = gpt4_grounder.inference_qa(frames_fake, question_mc, options_mc)
    print("GPT-4 QA Answer:", answer_gpt4)