# TStar/interface_llm.py
# Uploaded by ZihanWang314 using huggingface_hub (commit d686824, verified).
import base64
import io
import os
import re
from typing import Dict, List, Optional, Sequence, Tuple

import openai
import torch
import transformers
from PIL import Image
from tqdm import tqdm

try:
    import cv2
except ImportError:
    cv2 = None
    print("Warning: OpenCV is not installed, video frame extraction will not work.")

from TStar.utils import *
class LlavaInterface:
    """Placeholder wrapper for Llava model inference.

    Exposes the uniform ``inference(query, frames, **kwargs)`` method, plus the
    ``inference_with_frames_all_in_one`` entry point that ``TStarUniversalGrounder``
    calls on its backend. No model is loaded yet: ``inference`` only logs its
    inputs and returns a fixed placeholder string.
    """

    def __init__(self, model_path: str, model_base: Optional[str] = None):
        """Record where the Llava model lives.

        Args:
            model_path: Path to the Llava checkpoint.
            model_base: Optional base-model identifier.
        """
        # TODO: load the Llava model/tokenizer here, e.g.
        # self.tokenizer, self.model = ...
        self.model_path = model_path
        self.model_base = model_base
        print(f"[LlavaInterface] model_path={model_path}, model_base={model_base}")

    def inference(
        self,
        query: str,
        frames: Optional[List["Image.Image"]] = None,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        top_p: Optional[float] = None,
        num_beams: int = 1,
        max_tokens: int = 512,
        **kwargs
    ) -> str:
        """Run (placeholder) inference.

        Args:
            query: User prompt; may contain text plus ``<image>`` tags.
            frames: Images matching the ``<image>`` tags, or None.
            system_message: System prompt.
            temperature, top_p, num_beams, max_tokens: Decoding options
                (currently unused by the stub).
            **kwargs: Extra backend-specific options (currently unused).

        Returns:
            A fixed placeholder response until real Llava inference is implemented.
        """
        # TODO: replace this stub with real Llava inference.
        print("[LlavaInterface] Inference called with query:", query)
        print("[LlavaInterface] frames count:", len(frames) if frames else 0)
        return "Fake Response from LlavaInterface"

    def inference_with_frames_all_in_one(
        self,
        query: str,
        frames: Optional[List["Image.Image"]] = None,
        system_message: str = "You are a helpful assistant.",
        temperature: float = 0.2,
        max_tokens: int = 512,
        **kwargs
    ) -> str:
        """Adapter matching ``GPT4Interface.inference_with_frames_all_in_one``.

        ``TStarUniversalGrounder`` calls this method name on whichever backend it
        wraps. Without it, the "llava" backend fails with AttributeError. The call
        is forwarded unchanged to :meth:`inference`.
        """
        return self.inference(
            query=query,
            frames=frames,
            system_message=system_message,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )
class GPT4Interface:
    """Wrapper around the OpenAI chat-completions API for text-only and image+text prompts.

    Every ``inference_*`` method returns the model reply as a stripped string.
    Failures (frame encoding or API errors) are reported by returning a string
    starting with ``"Error"`` instead of raising, so callers must inspect the text.
    """

    def __init__(self, model="gpt-4o", api_key=None):
        """Configure the OpenAI client.

        Args:
            model: Chat model name used by every inference method.
            api_key: OpenAI API key. When None, the key is read from the
                ``OPENAI_API_KEY`` environment variable.

        Raises:
            ValueError: If no usable key is available (missing or empty).
        """
        self.model_name = model
        self.api_key = api_key if api_key is not None else os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("Environment variable OPENAI_API_KEY is not set.")
        # NOTE: sets the key on the openai module itself, i.e. process-wide.
        openai.api_key = self.api_key

    @staticmethod
    def _image_url_part(frame_base64: str) -> dict:
        """Wrap a base64-encoded image in a low-detail OpenAI ``image_url`` content part."""
        return {
            "type": "image_url",
            "image_url": {
                # The data URL declares JPEG; presumably encode_image_to_base64
                # (from TStar.utils) emits JPEG bytes -- confirm there.
                "url": f"data:image/jpeg;base64,{frame_base64}",
                "detail": "low",
            },
        }

    @classmethod
    def _encode_frames(cls, frames) -> Tuple[Optional[list], Optional[str]]:
        """Encode frames into ``image_url`` parts.

        Returns:
            ``(parts, None)`` on success, or ``(None, message)`` where *message*
            names the first frame that failed to encode.
        """
        parts = []
        for i, frame in enumerate(frames):
            try:
                parts.append(cls._image_url_part(encode_image_to_base64(frame)))
            except Exception as e:
                return None, f"Error encoding frame {i}: {str(e)}"
        return parts, None

    @staticmethod
    def _interleave_plan(query: str, num_frames: int) -> List[Tuple[str, object]]:
        """Return ordered ``("text", str)`` / ``("image", frame_index)`` slots for *query*.

        The query is split on ``<image>`` tags, and the i-th tag is filled with
        frame i. Blank text segments are dropped. Tags beyond ``num_frames`` stay
        empty, and frames beyond the number of tags are not used.
        """
        parts = query.split("<image>")
        plan = []
        for i, part in enumerate(parts):
            text = part.strip()
            if text:
                plan.append(("text", text))
            # Only segments before the last one are followed by an <image> tag.
            if i < len(parts) - 1 and i < num_frames:
                plan.append(("image", i))
        return plan

    def _chat(self, messages: list, temperature: float, max_tokens: int) -> str:
        """Send *messages* to ``self.model_name``.

        Returns the stripped reply, or ``"Error: ..."`` if the call fails.
        """
        try:
            response = openai.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Deliberate: callers of this class expect an error string, not an exception.
            return f"Error: {str(e)}"

    def inference_text_only(self, query: str, system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """Answer a text-only prompt.

        Args:
            query: User prompt.
            system_message: System prompt guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model reply, or an ``"Error: ..."`` string.
        """
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": query},
        ]
        return self._chat(messages, temperature, max_tokens)

    def inference_with_frames(self, query: str, frames: List["Image.Image"], system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """Answer *query* with all *frames* attached after the text.

        Args:
            query: User prompt.
            frames: PIL images used as visual context. None is treated as no images.
            system_message: System prompt guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model reply, or an ``"Error ..."`` string.
        """
        image_parts, error = self._encode_frames(frames or [])
        if error is not None:
            return error
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": [{"type": "text", "text": query}] + image_parts},
        ]
        return self._chat(messages, temperature, max_tokens)

    def inference_qa(self, question: str, options: str, frames: Optional[List["Image.Image"]] = None, system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 500) -> str:
        """Answer a multiple-choice question, optionally with images as context.

        Args:
            question: The question to answer.
            options: Multiple-choice options formatted as a single string.
            frames: Optional PIL images used as visual context.
            system_message: System prompt guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model reply (expected to be the option letter), or an ``"Error ..."`` string.
        """
        query = f"Question: {question}\nOptions: {options}\nAnswer with the letter corresponding to the best choice."
        image_parts, error = self._encode_frames(frames or [])
        if error is not None:
            return error
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": [{"type": "text", "text": query}] + image_parts},
        ]
        # Uses self.model_name, like every other method. This was previously
        # hard-coded to "gpt-4o", which ignored the configured model.
        return self._chat(messages, temperature, max_tokens)

    def inference_with_frames_all_in_one(self, query: str, frames: List["Image.Image"], system_message: str = "You are a helpful assistant.", temperature: float = 0.7, max_tokens: int = 1000) -> str:
        """Answer *query* with images placed at its ``<image>`` tags.

        Args:
            query: User prompt. Each ``<image>`` tag is replaced, in order, by the
                next frame.
            frames: PIL images for the tags. None is treated as no images.
            system_message: System prompt guiding the model.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The model reply, or an ``"Error ..."`` string.
        """
        frames = frames or []
        inputs = []
        for kind, value in self._interleave_plan(query, len(frames)):
            if kind == "text":
                inputs.append({"type": "text", "text": value})
                continue
            try:
                inputs.append(self._image_url_part(encode_image_to_base64(frames[value])))
            except Exception as e:
                return f"Error encoding frame {value}: {str(e)}"
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": inputs},
        ]
        return self._chat(messages, temperature, max_tokens)
class TStarUniversalGrounder:
    """Video question grounding and QA on top of a pluggable vision-language backend.

    Combines the former TStarGrounder and TStarGPTGrounder: ``backend`` selects
    whether requests go to ``LlavaInterface`` or ``GPT4Interface``. Every request
    goes through the backend's ``inference_with_frames_all_in_one`` method, and
    the prompt contains one ``<image>`` tag per frame.
    """

    def __init__(
        self,
        backend: str = "gpt",
        model_name: str = "gpt-4o",
        model_path: Optional[str] = None,
        model_base: Optional[str] = None,
        gpt4_api_key: Optional[str] = None,
        num_frames: Optional[int] = 8,
    ):
        """Create the backend interface.

        Args:
            backend: "llava", or "gpt4" (alias "gpt", the default). Case-insensitive.
            model_name: Model name for the GPT-4 backend.
            model_path: Llava model path; required when backend is "llava".
            model_base: Llava base model / version.
            gpt4_api_key: OpenAI key for the GPT-4 backend. When None,
                GPT4Interface reads the OPENAI_API_KEY environment variable.
            num_frames: Number of frames sampled per video in
                :meth:`inference_query_grounding`.

        Raises:
            ValueError: Unknown backend, or "llava" without ``model_path``.
        """
        self.backend = backend.lower()
        if self.backend == "gpt":
            # The default value "gpt" used to be rejected by the check below;
            # treat it as an alias of "gpt4".
            self.backend = "gpt4"
        self.num_frames = num_frames
        if self.backend == "llava":
            if not model_path:
                raise ValueError("Please provide model_path for LlavaInterface")
            self.VLM_model_interfance = LlavaInterface(model_path=model_path, model_base=model_base)
        elif self.backend == "gpt4":
            self.VLM_model_interfance = GPT4Interface(model=model_name, api_key=gpt4_api_key)
        else:
            raise ValueError("backend must be either 'llava' or 'gpt4'.")

    def inference_query_grounding(
        self,
        video_path: str,
        question: str,
        options: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 512
    ) -> Tuple[List[str], List[str]]:
        """Ask the model which objects to look for in the video to answer *question*.

        Args:
            video_path: Video file; ``self.num_frames`` frames are sampled from it.
            question: Question about the video.
            options: Optional multiple-choice options, appended to the prompt.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            ``(target_objects, cue_objects)``: target objects that locate the answer,
            and cue objects likely to appear near them.

        Raises:
            ValueError: If the reply does not contain two non-blank lines.
        """
        frames = load_video_frames(video_path=video_path, num_frames=self.num_frames)
        system_prompt = (
            "Here is a video:\n"
            + "\n".join(["<image>"] * len(frames))
            + "\nHere is a question about the video:\n"
            f"Question: {question}\n"
        )
        if options:
            system_prompt += f"Options: {options}\n"
        system_prompt += (
            "\nWhen answering this question about the video:\n"
            "1. What key objects to locate the answer?\n"
            "   - List potential key objects (short sentences, separated by commas).\n"
            "2. What cue objects might be near the key objects and might appear in the scenes?\n"
            "   - List potential cue objects (short sentences, separated by commas).\n\n"
            "Please provide your answer in two lines, directly listing the key and cue objects, separated by commas."
        )
        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return self._parse_grounding_response(response)

    def _parse_grounding_response(self, response: str) -> Tuple[List[str], List[str]]:
        """Split a two-line "key objects / cue objects" reply into two cleaned lists."""
        # Skip blank lines so a reply like "key...\n\ncue..." still yields both lists.
        lines = [line for line in response.split("\n") if line.strip()]
        if len(lines) < 2:
            raise ValueError(f"Unexpected response format from inference_query_grounding() --> {response}.")

        def clean_list(line: str) -> List[str]:
            cleaned = (self.check_objects_str(obj) for obj in line.split(",") if obj.strip())
            # Drop items that become empty after cleaning (e.g. a lone ".").
            return [obj for obj in cleaned if obj]

        return clean_list(lines[0]), clean_list(lines[1])

    def check_objects_str(self, obj: str):
        """Normalize one object name from the model reply.

        Lower-cases the name, then removes list markers ("1. ", "2. "), every ".",
        the "key objects: " / "cue objects: " labels, and any remaining ": ".
        """
        obj = obj.lower()
        obj = obj.strip().replace("1. ", "")
        obj = obj.strip().replace("2. ", "")
        obj = obj.strip().replace(".", "")
        obj = obj.strip().replace("key objects: ", "")
        obj = obj.strip().replace("cue objects: ", "")
        obj = obj.strip().replace(": ", "")
        return obj

    def inference_qa(
        self,
        frames: List["Image.Image"],
        question: str,
        options: str,
        temperature: float = 0.2,
        max_tokens: int = 128
    ) -> str:
        """Answer a multiple-choice question about the given frames.

        Args:
            frames: Video frames, one ``<image>`` tag each in the prompt.
            question: The question.
            options: Options formatted as a single string.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The stripped reply, expected to be the option letter (e.g. "A").
        """
        system_prompt = (
            "Select the best answer to the following multiple-choice question based on the video.\n"
            + "\n".join(["<image>"] * len(frames))
            + f"\nQuestion: {question}\n"
            + f"Options: {options}\n\n"
            "Answer with the option’s letter from the given choices directly."
        )
        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            # Previously hard-coded to 30, which silently ignored this parameter.
            max_tokens=max_tokens
        )
        return response.strip()

    def inference_openend_qa(
        self,
        frames: List["Image.Image"],
        question: str,
        temperature: float = 0.2,
        max_tokens: int = 2048
    ) -> str:
        """Answer an open-ended question about the given frames with a short reply.

        Args:
            frames: Video frames, one ``<image>`` tag each in the prompt.
            question: The question.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens in the reply.

        Returns:
            The stripped model reply.
        """
        system_prompt = (
            "Answer with the question in short based on the video.\n"
            + "\n".join(["<image>"] * len(frames))
            + f"\nQuestion: {question}\n"
        )
        response = self.VLM_model_interfance.inference_with_frames_all_in_one(
            query=system_prompt,
            frames=frames,
            temperature=temperature,
            # Previously hard-coded to 30, which silently ignored this parameter.
            max_tokens=max_tokens
        )
        return response.strip()
if __name__ == "__main__":
    # Manual smoke test against the GPT-4 backend (needs OPENAI_API_KEY and network).
    # Usage: python interface_llm.py <video_path>
    # The Llava backend can be tried with
    # TStarUniversalGrounder(backend="llava", model_path=...), but LlavaInterface
    # currently returns a placeholder response.
    import sys

    if len(sys.argv) < 2:
        sys.exit("Usage: python interface_llm.py <video_path>")
    video_path = sys.argv[1]

    print("\n=== Using GPT-4 backend ===")
    gpt4_grounder = TStarUniversalGrounder(
        backend="gpt4",
        model_name="gpt-4o",
        gpt4_api_key=None
    )

    # 1) Grounding: inference_query_grounding samples frames from the video itself.
    #    (It takes video_path; the old demo passed frames=, which raised TypeError.)
    searchable_objects = gpt4_grounder.inference_query_grounding(
        video_path=video_path,
        question="What objects are in the video?"
    )
    print("GPT-4 Grounding Result:", searchable_objects)

    # 2) Multiple-choice QA on frames sampled the same way as in step 1.
    frames = load_video_frames(video_path=video_path, num_frames=gpt4_grounder.num_frames)
    question_mc = "How many cats can be seen?\n"
    options_mc = "A) 0\nB) 1\nC) 2\nD) 3\n"
    answer_gpt4 = gpt4_grounder.inference_qa(frames, question_mc, options_mc)
    print("GPT-4 QA Answer:", answer_gpt4)