| | import os |
| | import base64 |
| | import tempfile |
| | from inference import Chat, get_conv_template |
| | import torch |
| |
|
def save_base64_to_tempfile(base64_str: str, suffix: str) -> str:
    """Decode a base64 payload to a named temporary file and return its path.

    Accepts either a raw base64 string or a data-URL style payload
    (``data:image/jpeg;base64,<data>``) — anything before the first comma
    is treated as a header and stripped.

    Args:
        base64_str: Base64-encoded content, optionally with a data-URL header.
        suffix: File suffix for the temp file (e.g. ``".jpg"``, ``".mp4"``).

    Returns:
        Path of the created file. The caller is responsible for deleting it
        (the file is created with ``delete=False``).

    Raises:
        binascii.Error: If the payload is not valid base64.
    """
    header_removed = base64_str
    if ',' in base64_str:
        # Drop the "data:<mime>;base64" prefix, keep only the payload.
        header_removed = base64_str.split(',', 1)[1]

    data = base64.b64decode(header_removed)
    # delete=False so the path survives the context manager for the caller;
    # the `with` guarantees the handle is closed even if write() fails.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(data)
        return tmp.name
| |
|
class EndpointHandler:
    """Stateful inference endpoint around a multimodal ``Chat`` session.

    Keeps a running conversation (``self.chat.conv``) plus the embedding of
    the most recently supplied image/video, so consecutive ``__call__``
    invocations continue the same dialogue until the client sends
    ``clear_history`` or new media arrives.

    NOTE(review): the exact contracts of ``Chat.ask`` / ``Chat.answer`` /
    ``get_*_embedding`` live in the project-local ``inference`` module and
    are not visible here — assumptions about them are marked below.
    """

    def __init__(self, model_path: str):
        # Prefer GPU when available; fall back to CPU inference.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=model_path,
            device=device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        # Embedding of the last image/video the client sent (None for text-only).
        self.vision_feature = None
        # One of "text" | "image" | "video"; selects the prompting mode.
        self.modal_type = "text"
        # Start from a fresh copy of the "husky" conversation template.
        self.chat.conv = get_conv_template("husky").copy()

    def __call__(self, data: dict) -> dict:
        """Run one inference turn.

        Args:
            data: Request payload. Recognized keys:
                ``inputs`` (str prompt), ``image`` / ``video`` (file path or
                base64 payload), ``clear_history`` (truthy to reset state).

        Returns:
            ``{"output": <reply>}`` on success, ``{"error": <message>}`` on
            failure.
        """
        # Reset all per-conversation state on request.
        if data.get("clear_history"):
            self.chat.conv = get_conv_template("husky").copy()
            self.vision_feature = None
            self.modal_type = "text"

        prompt = data.get("inputs", "")
        image_input = data.get("image", None)
        video_input = data.get("video", None)

        print("📨 收到 prompt:", repr(prompt))

        if image_input:
            if os.path.exists(image_input):
                # Input is a server-local file path.
                self.vision_feature = self.chat.get_image_embedding(image_input)
            else:
                # Base64 payload: decode to a temp file, embed, then clean up.
                # try/finally fixes a temp-file leak present in the original:
                # the file was never unlinked if the embedding call raised.
                tmp_path = save_base64_to_tempfile(image_input, suffix=".jpg")
                try:
                    self.vision_feature = self.chat.get_image_embedding(tmp_path)
                finally:
                    os.unlink(tmp_path)
            self.modal_type = "image"
            # New media starts a fresh conversation.
            self.chat.conv = get_conv_template("husky").copy()

        elif video_input:
            if os.path.exists(video_input):
                self.vision_feature = self.chat.get_video_embedding(video_input)
            else:
                tmp_path = save_base64_to_tempfile(video_input, suffix=".mp4")
                print("📼 保存临时视频路径:", tmp_path)
                try:
                    self.vision_feature = self.chat.get_video_embedding(tmp_path)
                finally:
                    os.unlink(tmp_path)
            self.modal_type = "video"
            self.chat.conv = get_conv_template("husky").copy()

            # Debug aid: presumably get_video_embedding returns a torch.Tensor;
            # warn loudly if not — TODO confirm against inference module.
            if isinstance(self.vision_feature, torch.Tensor):
                print("📏 视觉特征张量 shape:", self.vision_feature.shape)
            else:
                print("❌ self.vision_feature 不是张量,类型:", type(self.vision_feature))

        else:
            # No media in this request: fall back to plain text mode.
            self.modal_type = "text"
            self.vision_feature = None

        try:
            print("🧠 当前 modal_type:", self.modal_type)
            print("🧠 是否有视觉特征:", self.vision_feature is not None)

            conversations = self.chat.ask(prompt, self.chat.conv, modal_type=self.modal_type)
            output = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)

            # Strip once instead of three times as in the original.
            reply = output.strip()
            print("📤 推理输出:", repr(reply))

            # Record the reply in the conversation history so the next turn
            # sees it — assumes Chat.ask appended a placeholder message slot;
            # verify against the inference module.
            self.chat.conv.messages[-1][1] = reply
            return {"output": reply}

        except Exception as e:
            # Service boundary: report the failure to the client rather than
            # crashing the endpoint; full traceback goes to the server log.
            import traceback
            print("❌ 推理出错:")
            traceback.print_exc()
            return {"error": str(e)}
|