Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| # =========================================================================================== | |
| # | |
| # Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). All rights reserved. | |
| # | |
| # Author : Fan Zhang | |
| # Email : zhangfan@baai.ac.cn | |
| # Institute : Beijing Academy of Artificial Intelligence (BAAI) | |
| # Create On : 2023-12-12 02:54 | |
| # Last Modified : 2023-12-21 04:17 | |
| # File Name : meta.py | |
| # Description : | |
| # | |
| # =========================================================================================== | |
| import base64 | |
| from dataclasses import dataclass | |
| import io | |
| from enum import Enum | |
| from PIL import Image | |
| from typing import List, Tuple | |
| import cv2 | |
| import numpy as np | |
| from .constants import EVA_IMAGE_SIZE, GRD_SYMBOL, BOP_SYMBOL, EOP_SYMBOL, BOO_SYMBOL, EOO_SYMBOL | |
| from .constants import DEFAULT_VIDEO_TOKEN, DEFAULT_EOS_TOKEN, USER_TOKEN, ASSISTANT_TOKEN, FAKE_VIDEO_END_TOKEN | |
| from .utils import gen_id, frontend_logger as logging | |
class Role(Enum):
    """Who authored a chat message.

    Member values are 1-tuples, written explicitly here; the original
    ``NAME = n,`` syntax (trailing comma) produced the same tuples.
    """

    UNKNOWN = (0,)
    USER = (1,)
    ASSISTANT = (2,)
class DataType(Enum):
    """Kind of payload carried by a chat element.

    Member values are 1-tuples, written explicitly here; the original
    ``NAME = n,`` syntax (trailing comma) produced the same tuples.
    """

    UNKNOWN = (0,)
    TEXT = (1,)
    IMAGE = (2,)
    GROUNDING = (3,)
    VIDEO = (4,)
    ERROR = (5,)
| class DataMeta: | |
| datatype: DataType = DataType.UNKNOWN | |
| text: str = None | |
| image: Image.Image = None | |
| mask: Image.Image = None | |
| coordinate: List[int] = None | |
| frames: List[Image.Image] = None | |
| stack_frame: Image.Image = None | |
| def grounding(self): | |
| return self.coordinate is not None | |
| def text_str(self): | |
| return self.text | |
| def image_str(self): | |
| return self.image2str(self.image) | |
| def video_str(self): | |
| ret = f'<div style="overflow:scroll"><b>[VIDEO]</b></div>{self.image2str(self.stack_frame)}' | |
| return ret | |
| def grounding_str(self): | |
| ret = "" | |
| if self.text is not None: | |
| ret += f'<div style="overflow:scroll"><b>[PHRASE]</b>{self.text}</div>' | |
| ret += self.image2str(self.mask) | |
| if self.image is not None: | |
| ret += self.image2str(self.image) | |
| return ret | |
| def image2str(self, image): | |
| buf = io.BytesIO() | |
| image.save(buf, format="WEBP") | |
| i_str = base64.b64encode(buf.getvalue()).decode() | |
| return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>' | |
| def format_chatbot(self): | |
| match self.datatype: | |
| case DataType.TEXT | DataType.ERROR: | |
| return self.text_str | |
| case DataType.IMAGE: | |
| return self.image_str | |
| case DataType.VIDEO: | |
| return self.video_str | |
| case DataType.GROUNDING: | |
| return self.grounding_str | |
| case _: | |
| return "" | |
| def format_prompt(self) -> List[str | Image.Image]: | |
| match self.datatype: | |
| case DataType.TEXT: | |
| return [self.text] | |
| case DataType.IMAGE: | |
| return [self.image] | |
| case DataType.VIDEO: | |
| return [DEFAULT_VIDEO_TOKEN] + self.frames + [FAKE_VIDEO_END_TOKEN] | |
| case DataType.GROUNDING: | |
| ret = [] | |
| if self.text is not None: | |
| ret.append(f"{BOP_SYMBOL}{self.text}{EOP_SYMBOL}") | |
| ret += [BOO_SYMBOL, self.mask, EOO_SYMBOL] | |
| if self.image is not None: | |
| ret.append(self.image) | |
| return ret | |
| case _: | |
| return [] | |
| def __str__(self): | |
| s = "" | |
| if self.text is not None: | |
| s += f"T:{self.text}" | |
| if self.image is not None: | |
| w, h = self.image.size | |
| s += f"[I:{h}x{w}]" | |
| if self.coordinate is not None: | |
| l, t, r, b = self.coordinate | |
| s += f"[C:({l:03d},{t:03d}),({r:03d},{b:03d})]" | |
| if self.frames is not None: | |
| w, h = self.frames[0].size | |
| s += f"[V:{len(self.frames)}x{h}x{w}]" | |
| return s | |
| def build(cls, text=None, image=None, coordinate=None, frames=None, is_error=False, *, resize: bool = True): | |
| ins = cls() | |
| ins.text = text if text != "" else None | |
| ins.image = cls.resize(image, force=resize) | |
| # ins.image = image | |
| ins.coordinate = cls.fix(coordinate) | |
| ins.frames = cls.resize(frames, force=resize) | |
| # ins.frames = frames | |
| if is_error: | |
| ins.datatype = DataType.ERROR | |
| elif coordinate is not None: | |
| ins.datatype = DataType.GROUNDING | |
| ins.draw_box() | |
| elif image is not None: | |
| ins.datatype = DataType.IMAGE | |
| elif text is not None: | |
| ins.datatype = DataType.TEXT | |
| else: | |
| ins.datatype = DataType.VIDEO | |
| ins.stack() | |
| return ins | |
| def fix(cls, coordinate): | |
| if coordinate is None: | |
| return None | |
| l, t, r, b = coordinate | |
| l = min(EVA_IMAGE_SIZE, max(0, l)) | |
| t = min(EVA_IMAGE_SIZE, max(0, t)) | |
| r = min(EVA_IMAGE_SIZE, max(0, r)) | |
| b = min(EVA_IMAGE_SIZE, max(0, b)) | |
| return min(l, r), min(t, b), max(l, r), max(t, b) | |
| def resize(cls, image: Image.Image | List[Image.Image] | None, *, force: bool = True): | |
| if image is None: | |
| return None | |
| if not force: | |
| return image | |
| if isinstance(image, Image.Image): | |
| image = [image] | |
| for idx, im in enumerate(image): | |
| w, h = im.size | |
| if w < EVA_IMAGE_SIZE or h < EVA_IMAGE_SIZE: | |
| continue | |
| if w < h: | |
| h = int(EVA_IMAGE_SIZE / w * h) | |
| w = EVA_IMAGE_SIZE | |
| else: | |
| w = int(EVA_IMAGE_SIZE / h * w) | |
| h = EVA_IMAGE_SIZE | |
| image[idx] = im.resize((w, h)) | |
| return image if len(image) > 1 else image[0] | |
| def draw_box(self): | |
| left, top, right, bottom = self.coordinate | |
| mask = np.zeros((EVA_IMAGE_SIZE, EVA_IMAGE_SIZE, 3), dtype=np.uint8) | |
| mask = cv2.rectangle(mask, (left, top), (right, bottom), (255, 255, 255), 3) | |
| self.mask = Image.fromarray(mask) | |
| def stack(self): | |
| w, h = self.frames[0].size | |
| n = len(self.frames) | |
| stack_frame = Image.new(mode="RGB", size=(w*n, h)) | |
| for idx, f in enumerate(self.frames): | |
| stack_frame.paste(f, (idx*w, 0)) | |
| self.stack_frame = stack_frame | |
class ConvMeta:
    """State of one chat round: a system prompt plus an ordered list of (Role, DataMeta) messages.

    ``log_id`` comes from ``gen_id()`` and is prefixed to every log line. It
    is regenerated by :meth:`clear`.
    """

    def __init__(self):
        self.system: str = "You are a helpful assistant, dedicated to delivering comprehensive and meticulous responses."
        self.message: List[Tuple[Role, DataMeta]] = []
        self.log_id: str = gen_id()
        logging.info(f"{self.log_id}: create new round of chat")

    def append(self, r: Role, p: DataMeta):
        """Add message ``p`` authored by ``r`` to the end of the conversation."""
        logging.info(f"{self.log_id}: APPEND [{r.name}] prompt element, type: {p.datatype.name}, message: {p}")
        self.message.append((r, p))

    def format_chatbot(self):
        """Return chatbot pairs: ``(content, None)`` for USER messages, ``(None, content)`` otherwise."""
        ret = []
        for r, p in self.message:
            cur_p = p.format_chatbot()
            if r == Role.USER:
                ret.append((cur_p, None))
            else:
                ret.append((None, cur_p))
        return ret

    def format_prompt(self):
        """Concatenate every message's prompt pieces, without role markers.

        GRD_SYMBOL is prepended when any message is a GROUNDING element.
        """
        ret = []
        has_coor = False
        for _, p in self.message:
            has_coor |= (p.datatype == DataType.GROUNDING)
            ret += p.format_prompt()
        if has_coor:
            ret.insert(0, GRD_SYMBOL)
        logging.info(f"{self.log_id}: format generation prompt: {ret}")
        return ret

    def format_chat(self):
        """Build a chat-style prompt: system text, role-tagged turns, trailing assistant tag.

        A role header is emitted whenever the author changes:
        - the first turn gets `` USER_TOKEN: ``;
        - a switch to the assistant gets `` ASSISTANT_TOKEN: ``;
        - a switch back from the assistant gets ``DEFAULT_EOS_TOKEN`` + ``USER_TOKEN: ``.
        """
        ret = [self.system]
        prev_r = None
        for r, p in self.message:
            if prev_r != r:
                if prev_r == Role.ASSISTANT:
                    ret.append(f"{DEFAULT_EOS_TOKEN}{USER_TOKEN}: ")
                elif prev_r is None:
                    # NOTE(review): assumes the first message is from the user.
                    ret.append(f" {USER_TOKEN}: ")
                else:
                    ret.append(f" {ASSISTANT_TOKEN}: ")
                prev_r = r
            ret += p.format_prompt()
        ret.append(f" {ASSISTANT_TOKEN}:")
        logging.info(f"{self.log_id}: format chat prompt: {ret}")
        return ret

    def clear(self):
        """Drop all messages and start a new chat round with a fresh ``log_id``."""
        logging.info(f"{self.log_id}: clear chat history, end current chat round.")
        del self.message
        self.message = []
        self.log_id = gen_id()

    def pop(self):
        """Remove the last message, but only if it was not authored by the user."""
        # has_gen must be evaluated. Testing a bound method is always truthy,
        # which would also pop the user's own message.
        if self.has_gen:
            logging.info(f"{self.log_id}: pop out previous generation / chat result")
            self.message.pop()

    def pop_error(self):
        """Remove every ERROR message from the conversation, logging each one."""
        new_message = []
        for r, p in self.message:
            if p.datatype == DataType.ERROR:
                # Log the text itself. ``p.text_str`` is only text when it is a property.
                logging.info(f"{self.log_id}: pop error message: {p.text}")
            else:
                new_message.append((r, p))
        del self.message
        self.message = new_message

    @property
    def has_gen(self):
        """True when the conversation is non-empty and its last message is not from the user."""
        return bool(self.message) and self.message[-1][0] != Role.USER

    def __len__(self):
        """Number of messages in the conversation."""
        return len(self.message)