"""LLM/VLM back-end wrappers and the T* grounding front-end built on top of them.

The module exposes two interchangeable back-ends (``LlavaInterface`` and
``GPT4Interface``) and ``TStarUniversalGrounder``, which builds the prompts for
object grounding and question answering and forwards them to the chosen
back-end.
"""

import base64
import io
import os
import re
from typing import Dict, List, Optional, Sequence, Tuple

import openai
import torch
import transformers
from PIL import Image
from tqdm import tqdm

try:
    import cv2
except ImportError:
    cv2 = None
    print("Warning: OpenCV is not installed, video frame extraction will not work.")

# NOTE(review): `encode_image_to_base64` and `load_video_frames` are not defined
# in this file; they are presumably provided by this star import -- confirm.
from TStar.utils import *

# Placeholder that marks where a frame is spliced into a prompt. The prompt
# builders and `inference_with_frames_all_in_one` must agree on it, so it is
# defined exactly once. (str.split("") raises ValueError, so it must be non-empty.)
_IMAGE_TAG = "<image>"


class LlavaInterface:
    """Stub wrapper around a Llava model.

    The model itself is not loaded here; `inference` only logs its inputs and
    returns a fixed string. The class exists to show the interface a real
    implementation has to provide.
    """

    def __init__(self, model_path: str, model_base: Optional[str] = None):
        """Remember where the model lives.

        Args:
            model_path: Path of the Llava checkpoint.
            model_base: Optional base-model identifier.
        """
        # A real implementation would load the tokenizer and model here.
        self.model_path = model_path
        self.model_base = model_base
        print(f"[LlavaInterface] model_path={model_path}, model_base={model_base}")

    def inference(
        self,
        query: str,
        frames: Optional[List[Image.Image]] = None,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        top_p: Optional[float] = None,
        num_beams: int = 1,
        max_tokens: int = 512,
        **kwargs
    ) -> str:
        """Unified inference entry point (placeholder implementation).

        Args:
            query: User prompt; may contain image placeholders.
            frames: Frames that accompany the prompt.
            system_message: System prompt.
            temperature, top_p, num_beams, max_tokens: Generation settings;
                accepted but unused by this stub.
            **kwargs: Ignored.

        Returns:
            A fixed fake response.
        """
        print("[LlavaInterface] Inference called with query:", query)
        print("[LlavaInterface] frames count:", len(frames) if frames else 0)
        # A real implementation would run the Llava model here.
        return "Fake Response from LlavaInterface"

    def inference_with_frames_all_in_one(
        self,
        query: str,
        frames: List[Image.Image],
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.7,
        max_tokens: int = 1000,
    ) -> str:
        """Mirror `GPT4Interface.inference_with_frames_all_in_one`.

        `TStarUniversalGrounder` calls this method on whichever back-end it
        holds, so the Llava back-end must offer it as well; it simply
        delegates to `inference`.
        """
        return self.inference(
            query=query,
            frames=frames,
            system_message=system_message,
            temperature=temperature,
            max_tokens=max_tokens,
        )


class GPT4Interface:
    """Thin client for OpenAI chat completions, with optional image inputs.

    Every public method returns the model's reply as a string. Failures are
    not raised: they are reported as a string starting with ``"Error"``.
    """

    def __init__(self, model="gpt-4o", api_key=None):
        """Configure the OpenAI client.

        Args:
            model: Name of the chat model used for every request.
            api_key: API key; when ``None`` the ``OPENAI_API_KEY`` environment
                variable is used instead.

        Raises:
            ValueError: If no key is passed and the environment variable is
                unset or empty.
        """
        self.api_key = api_key
        self.model_name = model
        if api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY")
            if not self.api_key:
                raise ValueError("Environment variable OPENAI_API_KEY is not set.")
        openai.api_key = self.api_key

    @staticmethod
    def _image_part(frame) -> Dict:
        """Build the ``image_url`` content part for one frame."""
        frame_base64 = encode_image_to_base64(frame)
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{frame_base64}",
                "detail": "low",
            },
        }

    def _chat(self, messages: List[Dict], temperature: float, max_tokens: int) -> str:
        """Send `messages` to the configured model and return the stripped reply."""
        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Callers of this class expect an error string, not an exception.
            return f"Error: {str(e)}"

    def inference_text_only(
        self,
        query: str,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.7,
        max_tokens: int = 1000,
    ) -> str:
        """Answer a text-only query.

        Args:
            query: User's query.
            system_message: System message guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model's reply, or ``"Error: ..."`` if the request failed.
        """
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ]
        return self._chat(messages, temperature, max_tokens)

    def inference_with_frames(
        self,
        query: str,
        frames: List[Image.Image],
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.7,
        max_tokens: int = 1000,
    ) -> str:
        """Answer a query with all frames appended after the text.

        Args:
            query: User's query.
            frames: Frames providing visual context.
            system_message: System message guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model's reply, ``"Error encoding frame i: ..."`` if a frame
            could not be encoded, or ``"Error: ..."`` if the request failed.
        """
        content = [{"type": "text", "text": query}]
        for i, frame in enumerate(frames):
            try:
                content.append(self._image_part(frame))
            except Exception as e:
                return f"Error encoding frame {i}: {str(e)}"

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": content},
        ]
        return self._chat(messages, temperature, max_tokens)

    def inference_qa(
        self,
        question: str,
        options: str,
        frames: Optional[List[Image.Image]] = None,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.7,
        max_tokens: int = 500,
    ) -> str:
        """Answer a multiple-choice question, optionally with frames.

        Args:
            question: The question to answer.
            options: The answer options, formatted as one string.
            frames: Optional frames providing visual context.
            system_message: System message guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model's reply, or an ``"Error..."`` string on failure.
        """
        query = f"Question: {question}\nOptions: {options}\nAnswer with the letter corresponding to the best choice."
        # Same request shape as `inference_with_frames`; delegating also makes
        # this method honour `self.model_name` instead of a hard-coded model.
        return self.inference_with_frames(
            query=query,
            frames=frames or [],
            system_message=system_message,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def inference_with_frames_all_in_one(
        self,
        query: str,
        frames: List[Image.Image],
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.7,
        max_tokens: int = 1000,
    ) -> str:
        """Answer a query whose text contains ``<image>`` placeholders.

        The query is split on the placeholder; after the i-th text piece the
        i-th frame is inserted, so text and images are interleaved in the
        order the prompt dictates. Frames beyond the number of text pieces
        are not sent.

        Args:
            query: Prompt containing ``<image>`` placeholders.
            frames: Frames to splice in at the placeholders.
            system_message: System message guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model's reply, ``"Error encoding frame i: ..."`` if a frame
            could not be encoded, or ``"Error: ..."`` if the request failed.
        """
        content = []
        for i, part in enumerate(query.split(_IMAGE_TAG)):
            text = part.strip()
            if text:
                content.append({"type": "text", "text": text})
            if i < len(frames):  # never index past the available frames
                try:
                    content.append(self._image_part(frames[i]))
                except Exception as e:
                    return f"Error encoding frame {i}: {str(e)}"

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": content},
        ]
        return self._chat(messages, temperature, max_tokens)


class TStarUniversalGrounder:
    """Grounding and QA front-end that works with either back-end.

    Combines the former TStarGrounder and TStarGPTGrounder; the `backend`
    argument selects whether `LlavaInterface` or `GPT4Interface` is used.
    """

    def __init__(
        self,
        backend: str = "gpt",
        model_name: str = "gpt-4o",
        model_path: Optional[str] = None,
        model_base: Optional[str] = None,
        gpt4_api_key: Optional[str] = None,
        num_frames: Optional[int] = 8,
    ):
        """Create the back-end interface.

        Args:
            backend: ``"llava"`` or ``"gpt4"`` (case-insensitive); ``"gpt"``
                is accepted as an alias of ``"gpt4"``.
            model_name: Model name passed to `GPT4Interface`.
            model_path: Llava checkpoint path (required for ``"llava"``).
            model_base: Llava base-model identifier.
            gpt4_api_key: API key passed to `GPT4Interface`.
            num_frames: Number of frames requested from `load_video_frames`
                during query grounding.

        Raises:
            ValueError: For an unknown backend, or a missing `model_path`
                with the llava backend.
        """
        normalized = backend.lower()
        if normalized == "gpt":
            # The default value used to be rejected by the check below.
            normalized = "gpt4"
        self.backend = normalized
        self.num_frames = num_frames

        if self.backend == "llava":
            if not model_path:
                raise ValueError("Please provide model_path for LlavaInterface")
            self.VLM_model_interfance = LlavaInterface(model_path=model_path, model_base=model_base)
        elif self.backend == "gpt4":
            self.VLM_model_interfance = GPT4Interface(model=model_name, api_key=gpt4_api_key)
        else:
            raise ValueError("backend must be either 'llava' or 'gpt4'.")

    def inference_query_grounding(
        self,
        video_path: str,
        question: str,
        options: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 512
    ) -> Tuple[List[str], List[str]]:
        """Ask the model which objects are worth searching for in the video.

        Args:
            video_path: Video to sample `self.num_frames` frames from.
            question: Question about the video.
            options: Optional answer options appended to the prompt.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            ``(target_objects, cue_objects)``: objects the answer depends on,
            and objects likely to appear near them.

        Raises:
            ValueError: If the reply has fewer than two non-blank lines.
        """
        frames = load_video_frames(video_path=video_path, num_frames=self.num_frames)

        # Build the prompt: one placeholder per frame, then the question.
        system_prompt = (
            "Here is a video:\n"
            + "\n".join([_IMAGE_TAG] * len(frames))
            + "\nHere is a question about the video:\n"
            f"Question: {question}\n"
        )
        if options:
            system_prompt += f"Options: {options}\n"
        system_prompt += (
            "\nWhen answering this question about the video:\n"
            "1. What key objects to locate the answer?\n"
            " - List potential key objects (short sentences, separated by commas).\n"
            "2. What cue objects might be near the key objects and might appear in the scenes?\n"
            " - List potential cue objects (short sentences, separated by commas).\n\n"
            "Please provide your answer in two lines, directly listing the key and cue objects, separated by commas."
        )

        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # Expected reply: key objects on the first line, cue objects on the
        # second. Blank lines are dropped so a reply that separates the two
        # lines with an empty one is still parsed correctly.
        lines = [line for line in response.split("\n") if line.strip()]
        if len(lines) < 2:
            raise ValueError(f"Unexpected response format from inference_query_grounding() --> {response}.")

        target_objects = [self.check_objects_str(obj) for obj in lines[0].split(",") if obj.strip()]
        cue_objects = [self.check_objects_str(obj) for obj in lines[1].split(",") if obj.strip()]
        return target_objects, cue_objects

    def check_objects_str(self, obj: str):
        """Lower-case an object name and remove list/label decorations.

        Removes, in this order, every occurrence of ``"1. "``, ``"2. "``,
        ``"."``, ``"key objects: "``, ``"cue objects: "`` and ``": "``,
        stripping surrounding whitespace before each removal.
        """
        obj = obj.lower()
        for token in ("1. ", "2. ", ".", "key objects: ", "cue objects: ", ": "):
            obj = obj.strip().replace(token, "")
        return obj

    def inference_qa(
        self,
        frames: List[Image.Image],
        question: str,
        options: str,
        temperature: float = 0.2,
        max_tokens: int = 128
    ) -> str:
        """Answer a multiple-choice question about the given frames.

        Args:
            frames: Frames shown to the model.
            question: The question.
            options: The answer options, formatted as one string.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The stripped model reply (expected to be an option letter).
        """
        system_prompt = (
            "Select the best answer to the following multiple-choice question based on the video.\n"
            + "\n".join([_IMAGE_TAG] * len(frames))
            + f"\nQuestion: {question}\n"
            + f"Options: {options}\n\n"
            "Answer with the option’s letter from the given choices directly."
        )

        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens,  # was hard-coded to 30, ignoring the argument
        )
        return response.strip()

    def inference_openend_qa(
        self,
        frames: List[Image.Image],
        question: str,
        temperature: float = 0.2,
        max_tokens: int = 2048
    ) -> str:
        """Answer an open-ended question about the given frames.

        Args:
            frames: Frames shown to the model.
            question: The question.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The stripped model reply.
        """
        system_prompt = (
            "Answer with the question in short based on the video.\n"
            + "\n".join([_IMAGE_TAG] * len(frames))
            + f"\nQuestion: {question}\n"
        )

        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens,  # was hard-coded to 30, ignoring the argument
        )
        return response.strip()


if __name__ == "__main__":
    # Manual smoke test for the GPT-4 back-end. The Llava back-end is only a
    # stub, so it is not exercised here.
    import sys

    if len(sys.argv) != 2:
        sys.exit(f"Usage: python {sys.argv[0]} <video_path>")
    demo_video_path = sys.argv[1]

    print("\n=== Using GPT-4 backend ===")
    gpt4_grounder = TStarUniversalGrounder(
        backend="gpt4",
        model_name="gpt-4o",
        gpt4_api_key=None
    )

    # 1) Query grounding: takes a video path, not a list of frames.
    searchable_objects = gpt4_grounder.inference_query_grounding(
        video_path=demo_video_path,
        question="What objects are in the video?"
    )
    print("GPT-4 Grounding Result:", searchable_objects)

    # 2) Multiple-choice QA on frames sampled the same way grounding samples them.
    demo_frames = load_video_frames(video_path=demo_video_path, num_frames=gpt4_grounder.num_frames)
    question_mc = "How many cats can be seen?\n"
    options_mc = "A) 0\nB) 1\nC) 2\nD) 3\n"
    answer_gpt4 = gpt4_grounder.inference_qa(demo_frames, question_mc, options_mc)
    print("GPT-4 QA Answer:", answer_gpt4)