import torch
import os
from tqdm import tqdm
from typing import Dict, List, Optional, Sequence, Tuple
import transformers
import re
import openai
from PIL import Image
import base64
import io

try:
    import cv2
except ImportError:
    cv2 = None
    print("Warning: OpenCV is not installed; video frame extraction will not work.")

from TStar.utils import *
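
# The two helpers used below, `encode_image_to_base64` and `load_video_frames`,
# are expected to come from the star import above. The fallback sketches here
# are assumptions inferred from their call sites in this file, defined only if
# TStar.utils did not provide them; they are not the canonical implementations.
if "encode_image_to_base64" not in globals():
    def encode_image_to_base64(image: Image.Image) -> str:
        # Serialize the PIL image as JPEG and return a base64 string suitable
        # for an OpenAI data-URL image payload.
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

if "load_video_frames" not in globals():
    def load_video_frames(video_path: str, num_frames: int = 8) -> List[Image.Image]:
        # Uniformly sample `num_frames` frames from the video as PIL images.
        if cv2 is None:
            raise ImportError("OpenCV is required for video frame extraction.")
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        indices = [int(i * total / num_frames) for i in range(num_frames)]
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, frame = cap.read()
            if ok:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        cap.release()
        return frames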


class LlavaInterface:
    """
    Example wrapper around LLaVA model inference.
    The key point is to expose a uniform method: inference(query, frames, **kwargs).
    """

    def __init__(self, model_path: str, model_base: Optional[str] = None):
        # Model loading (tokenizer, weights, etc.) would go here:
        # self.tokenizer, self.model = ...
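        #
        # A hedged sketch of real loading via the loader in the official LLaVA
        # repo (haotian-liu/LLaVA); the signature below is an assumption, so
        # verify it against your installed version:
        #
        #   from llava.model.builder import load_pretrained_model
        #   from llava.mm_utils import get_model_name_from_path
        #   tokenizer, model, image_processor, context_len = load_pretrained_model(
        #       model_path, model_base, get_model_name_from_path(model_path)
        #   )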
        self.model_path = model_path
        self.model_base = model_base
        print(f"[LlavaInterface] model_path={model_path}, model_base={model_base}")

    def inference(
        self,
        query: str,
        frames: Optional[List[Image.Image]] = None,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        top_p: Optional[float] = None,
        num_beams: int = 1,
        max_tokens: int = 512,
        **kwargs
    ) -> str:
        """
        Uniform inference entry point.

        query: user input; may contain text plus <image> tags.
        frames: the corresponding list of image frames.
        system_message: system prompt.
        Add further parameters as needed.
        """
        # Placeholder logic; the real inference call is left to you.
        print("[LlavaInterface] Inference called with query:", query)
        print("[LlavaInterface] frames count:", len(frames) if frames else 0)
        # In a real setting this would run the LLaVA model on the frames.
        return "Fake Response from LlavaInterface"


class GPT4Interface:
    def __init__(self, model="gpt-4o", api_key=None):
        """
        Initialize the GPT-4 API client.
        Falls back to the `OPENAI_API_KEY` environment variable when no key is given.
        """
        self.api_key = api_key
        self.model_name = model
        if api_key is None:
            self.api_key = os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("Environment variable OPENAI_API_KEY is not set.")
        openai.api_key = self.api_key

    def inference_text_only(self, query: str, system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """
        Perform text-only inference using the GPT-4 API.

        Args:
            query (str): User's query or input.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The response generated by the GPT-4 model.
        """
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ]
        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"

    def inference_with_frames(self, query: str, frames: List[Image.Image], system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """
        Perform inference using the GPT-4 API with video frames as context.

        Args:
            query (str): User's query or input.
            frames (List[Image.Image]): List of PIL.Image objects to provide visual context.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The response generated by the GPT-4 model.
        """
        # Build the user content: the query text followed by every frame.
        inputs = [{"type": "text", "text": query}]
        # Encode frames as base64 data URLs.
        for i, frame in enumerate(frames):
            try:
                frame_base64 = encode_image_to_base64(frame)
                visual_context = {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{frame_base64}",
                        "detail": "low"
                    }
                }
                # Append visual context (images) to the user content.
                inputs.append(visual_context)
            except Exception as e:
                return f"Error encoding frame {i}: {str(e)}"
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]
        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"

    def inference_qa(self, question: str, options: str, frames: List[Image.Image] = None, system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 500) -> str:
        """
        Perform inference for a multiple-choice question with optional visual frames as context.

        Args:
            question (str): The question to answer.
            options (str): Multiple-choice options formatted as a string.
            frames (List[Image.Image], optional): List of PIL.Image objects to provide additional visual context.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The selected option or answer.
        """
        # Construct the query.
        query = f"Question: {question}\nOptions: {options}\nAnswer with the letter corresponding to the best choice."
        inputs = [{"type": "text", "text": query}]
        if frames:
            # Encode frames as base64 data URLs.
            for i, frame in enumerate(frames):
                try:
                    frame_base64 = encode_image_to_base64(frame)
                    visual_context = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{frame_base64}",
                            "detail": "low"
                        }
                    }
                    inputs.append(visual_context)
                except Exception as e:
                    return f"Error encoding frame {i}: {str(e)}"
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]
        try:
            response = openai.chat.completions.create(
                model=self.model_name,  # was hardcoded to "gpt-4o"; use the configured model
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"

    def inference_with_frames_all_in_one(self, query: str, frames: List[Image.Image], system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """
        Perform inference using the GPT-4 API with video frames interleaved into the query.

        Args:
            query (str): User's query or input; frame positions are marked with <image> tags.
            frames (List[Image.Image]): List of PIL.Image objects to provide visual context.
            system_message (str): System message to guide the model's behavior.
            temperature (float): Sampling temperature for the response.
            max_tokens (int): Maximum number of tokens for the response.

        Returns:
            str: The response generated by the GPT-4 model.
        """
        # Split the query on <image> tags.
        parts = query.split("<image>")
        inputs = []
        # Add text and images alternately to the user content.
        for i, part in enumerate(parts):
            if part.strip():
                inputs.append({"type": "text", "text": part.strip()})
            if i < len(frames):  # Ensure we don't exceed the number of available frames.
                try:
                    frame_base64 = encode_image_to_base64(frames[i])
                    visual_context = {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{frame_base64}",
                            "detail": "low"
                        }
                    }
                    inputs.append(visual_context)
                except Exception as e:
                    return f"Error encoding frame {i}: {str(e)}"
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]
        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            return f"Error: {str(e)}"


class TStarUniversalGrounder:
    """
    Combines the functionality of the original TStarGrounder and TStarGPTGrounder.
    The `backend` parameter switches the underlying model between LlavaInterface
    and GPT4Interface.
    """

    def __init__(
        self,
        backend: str = "gpt4",
        model_name: str = "gpt-4o",
        model_path: Optional[str] = None,
        model_base: Optional[str] = None,
        gpt4_api_key: Optional[str] = None,
        num_frames: int = 8,
    ):
        """
        backend: "llava" or "gpt4"
        model_path, model_base: path and base version of the Llava model
        model_name, gpt4_api_key: GPT-4 model name and API key
        """
        self.backend = backend.lower()
        self.num_frames = num_frames
        if self.backend == "llava":
            # Initialize LlavaInterface.
            if not model_path:
                raise ValueError("Please provide model_path for LlavaInterface")
            self.VLM_model_interface = LlavaInterface(model_path=model_path, model_base=model_base)
        elif self.backend == "gpt4":
            # Initialize GPT4Interface.
            self.VLM_model_interface = GPT4Interface(model=model_name, api_key=gpt4_api_key)
        else:
            raise ValueError("backend must be either 'llava' or 'gpt4'.")

    def inference_query_grounding(
        self,
        video_path: str,
        question: str,
        options: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 512
    ) -> Tuple[List[str], List[str]]:
        """
        Identify the target objects that can ground the answer, plus the cue
        objects that may help locate them in the scene.
        """
        frames = load_video_frames(video_path=video_path, num_frames=self.num_frames)

        # Build the prompt.
        system_prompt = (
            "Here is a video:\n"
            + "\n".join(["<image>"] * len(frames))
            + "\nHere is a question about the video:\n"
            f"Question: {question}\n"
        )
        if options:
            system_prompt += f"Options: {options}\n"
        system_prompt += (
            "\nWhen answering this question about the video:\n"
            "1. Which key objects are needed to locate the answer?\n"
            "   - List potential key objects (short phrases, separated by commas).\n"
            "2. Which cue objects might appear near the key objects in the scenes?\n"
            "   - List potential cue objects (short phrases, separated by commas).\n\n"
            "Please provide your answer in two lines, directly listing the key and cue objects, separated by commas."
        )

        # Route everything through the shared interface; TODO: abstract this
        # into a single inference() entry point.
        response = self.VLM_model_interface.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        # Parse the response according to the expected two-line format.
        lines = response.split("\n")
        if len(lines) < 2:
            raise ValueError(f"Unexpected response format from inference_query_grounding() --> {response}.")
        target_objects = [self.check_objects_str(obj) for obj in lines[0].split(",") if obj.strip()]
        cue_objects = [self.check_objects_str(obj) for obj in lines[1].split(",") if obj.strip()]
        return target_objects, cue_objects
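
    # Illustrative example (not real model output): a response such as
    #   "red backpack, blue jacket\ncoat rack, doorway"
    # parses to target_objects == ["red backpack", "blue jacket"] and
    # cue_objects == ["coat rack", "doorway"] after check_objects_str cleanup.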
    def check_objects_str(self, obj: str):
        """Normalize a raw object phrase: lowercase, strip list markers ("1. ", "2. ") and label prefixes ("key objects: ", "cue objects: ")."""
        obj = obj.lower()  # lowercase
        obj = obj.strip().replace("1. ", "")
        obj = obj.strip().replace("2. ", "")
        obj = obj.strip().replace(".", "")
        obj = obj.strip().replace("key objects: ", "")
        obj = obj.strip().replace("cue objects: ", "")
        obj = obj.strip().replace(": ", "")
        return obj

    def inference_qa(
        self,
        frames: List[Image.Image],
        question: str,
        options: str,
        temperature: float = 0.2,
        max_tokens: int = 128
    ) -> str:
        """
        Multiple-choice inference; returns the most likely option letter
        (e.g., A, B, C, or D).
        """
        system_prompt = (
            "Select the best answer to the following multiple-choice question based on the video.\n"
            + "\n".join(["<image>"] * len(frames))
            + f"\nQuestion: {question}\n"
            + f"Options: {options}\n\n"
            "Answer with the option's letter from the given choices directly."
        )
        response = self.VLM_model_interface.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens  # was hardcoded to 30, ignoring the parameter
        )
        return response.strip()

    def inference_openend_qa(
        self,
        frames: List[Image.Image],
        question: str,
        temperature: float = 0.2,
        max_tokens: int = 2048
    ) -> str:
        """
        Open-ended inference; returns a short free-form answer.
        """
        system_prompt = (
            "Answer the question briefly based on the video.\n"
            + "\n".join(["<image>"] * len(frames))
            + f"\nQuestion: {question}\n"
        )
        response = self.VLM_model_interface.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens  # was hardcoded to 30, ignoring the parameter
        )
        return response.strip()


if __name__ == "__main__":
    """
    Test example.
    """
    # 1) Using Llava as the underlying model
    # print("=== Using Llava backend ===")
    # llava_grounder = TStarUniversalGrounder(
    #     backend="llava",
    #     model_path="/path/to/llava",
    #     model_base="v1.0"
    # )

    frames_fake = [  # random noise frames would be even better for a smoke test
        Image.open("./output_image.jpg"),
        Image.open("/home/yejinhui/Projects/VisualSearch/output_image.jpg")
    ]

    # result_grounding_llava = llava_grounder.inference_query_grounding(
    #     video_path="/path/to/video.mp4",
    #     question="What objects are in the video?",
    # )
    # print("Llava Grounding Result:", result_grounding_llava)

    # 2) Using GPT-4 as the underlying model
    print("\n=== Using GPT-4 backend ===")
    gpt4_grounder = TStarUniversalGrounder(
        backend="gpt4",
        model_name="gpt-4o",
        gpt4_api_key=None
    )
    # inference_query_grounding expects a video path (frames are sampled
    # internally), so pass a placeholder path rather than frames_fake.
    target_objects, cue_objects = gpt4_grounder.inference_query_grounding(
        video_path="./demo_video.mp4",  # placeholder path
        question="What objects are in the video?"
    )
    print("GPT-4 Grounding Result:", target_objects, cue_objects)

    # 3) Multiple-choice QA example
    question_mc = "How many cats can be seen?\n"
    options_mc = "A) 0\nB) 1\nC) 2\nD) 3\n"
    # answer_llava = llava_grounder.inference_qa(frames_fake, question_mc, options_mc)
    # print("Llava QA Answer:", answer_llava)
    answer_gpt4 = gpt4_grounder.inference_qa(frames_fake, question_mc, options_mc)
    print("GPT-4 QA Answer:", answer_gpt4)