Spaces:
Runtime error
Runtime error
| import os | |
| import numpy as np | |
| import datetime | |
| import json | |
| from typing import Optional | |
| import transformers | |
| from dataclasses import dataclass, field | |
| import io | |
| import spaces | |
| import base64 | |
| from PIL import Image | |
| import gradio as gr | |
| import time | |
| import hashlib | |
| from utils import build_logger | |
| from conversation import conv_seed_llama2 | |
| import hydra | |
| import pyrootutils | |
| import torch | |
| import re | |
| import time | |
| from omegaconf import OmegaConf | |
| from flask import Flask | |
| import json | |
| from typing import Optional | |
| import cv2 | |
| from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, StableDiffusionImg2ImgPipeline | |
# Locate the project root (the directory holding a `.project-root` marker file)
# and put it on the Python path so project-local packages import correctly.
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Tokens that wrap an image inside the LLM prompt: BOI_TOKEN, then
# num_img_in_tokens IMG_TOKEN placeholders, then EOI_TOKEN (see generate()).
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'  # e.g. IMG_TOKEN.format(3) == '<img_00003>'
# Marker for an image position inside user text (appended by add_image, split on by generate).
IMG_FLAG = '<image>'

# Visual tokens per input image (prompt side) and per generated image (agent side).
num_img_in_tokens = 64
num_img_out_tokens = 64

# Candidate "rows x cols" patch grids.
# NOTE(review): not referenced by the code visible in this file; presumably meant for
# multi-resolution input.
resolution_grids = (
    [f'1x{n}' for n in (1, 2, 3, 4, 5, 6, 10)]
    + [f'{n}x1' for n in (2, 3, 4, 5, 6, 10)]
    + ['2x2', '2x3', '3x2', '2x4', '4x2']
)
base_resolution = 448

# NOTE(review): this Flask app is not referenced elsewhere in the visible code.
app = Flask(__name__)
def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64 string (such as one produced by ``encode_image``) into a PIL image.

    Args:
        encoded_image: Base64 text of an encoded image file in any format PIL can open.

    Returns:
        The ``PIL.Image.Image`` that ``Image.open`` returns for the decoded bytes.
    """
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    # The BytesIO buffer is kept referenced by the returned image, so it stays valid.
    return Image.open(io.BytesIO(decoded_bytes))
def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    """Serialize *image* in the given ``format`` and return those bytes as base64 text.

    This is the counterpart of ``decode_image``, which base64-decodes and reopens the bytes.
    """
    buffer = io.BytesIO()
    try:
        image.save(buffer, format=format)
        payload = buffer.getvalue()
    finally:
        buffer.close()
    return base64.b64encode(payload).decode('utf-8')
@dataclass
class Arguments:
    """Command-line options for the demo, parsed with ``transformers.HfArgumentParser``.

    This must be a dataclass. ``HfArgumentParser`` builds its flags from the dataclass
    fields. Without the decorator, the class attributes would be raw ``Field`` objects
    rather than their defaults.
    """
    image_transform: Optional[str] = field(default='configs/processer/qwen_448_transform.yaml',
                                           metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default='configs/tokenizer/clm_llama_tokenizer.yaml',
                                     metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm: Optional[str] = field(default='configs/clm_models/llama2chat7b_lora.yaml',
                               metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default='configs/visual_tokenizer/qwen_vitg_448.yaml',
                                          metadata={"help": "config path of visual encoder"})
    sd_adapter: Optional[str] = field(default='configs/detokenizer/detokenizer_sdxl_qwen_vit_adapted.yaml',
                                      metadata={"help": "config path of sd adapter"})
    agent: Optional[str] = field(default='configs/clm_models/agent_7b_sft.yaml',
                                 metadata={"help": "config path of agent model"})
    diffusion_path: Optional[str] = field(default='stabilityai/stable-diffusion-xl-base-1.0',
                                          metadata={"help": "diffusion model path"})
    # The default is an int, so the annotation must be int (it was `str`) for the CLI to parse it consistently.
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
    # LLMService accepts 'fp16' or 'bf16' here.
    dtype: Optional[str] = field(default='fp16', metadata={"help": "mix percision"})
# Build a CLI from the `Arguments` dataclass and parse the process arguments into it.
# Exactly one dataclass instance is expected back, hence the single-element unpack.
parser = transformers.HfArgumentParser(Arguments)
(args,) = parser.parse_args_into_dataclasses()
class LLMService:
    """Load every model the demo uses and keep it on its configured device.

    The service is built once at import time from the parsed ``Arguments``. It holds:
    - the image transform and tokenizer;
    - the visual encoder;
    - the agent model that wraps the LLM;
    - the SD adapter, whose ``generate`` turns the agent's image features into images
      (see ``generate()``).
    """

    def __init__(self, args) -> None:
        self.llm_device = args.llm_device
        self.vit_sd_device = args.vit_sd_device
        # generate() may check this flag. Only single-resolution image input is set up
        # here, so it must exist and be False.
        self.multi_resolution = False

        dtype_by_name = {'fp16': torch.float16, 'bf16': torch.bfloat16}
        if args.dtype not in dtype_by_name:
            raise ValueError(f"Unsupported dtype {args.dtype!r}; expected 'fp16' or 'bf16'.")
        self.dtype = dtype_by_name[args.dtype]

        # Image preprocessing and tokenizer, both built from Hydra configs.
        self.image_transform = hydra.utils.instantiate(OmegaConf.load(args.image_transform))
        self.tokenizer = hydra.utils.instantiate(OmegaConf.load(args.tokenizer))

        self.visual_encoder = hydra.utils.instantiate(OmegaConf.load(args.visual_encoder))
        self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
        print('Init visual encoder done')

        llm = hydra.utils.instantiate(OmegaConf.load(args.llm), torch_dtype=self.dtype)
        print('Init llm done.')

        self.agent = hydra.utils.instantiate(OmegaConf.load(args.agent), llm=llm)
        self.agent.eval().to(self.llm_device, dtype=self.dtype)
        print('Init agent mdoel Done')

        # Diffusion components loaded from args.diffusion_path.
        noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(
            self.vit_sd_device, dtype=self.dtype)
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(
            self.vit_sd_device, dtype=self.dtype)

        self.sd_adapter = hydra.utils.instantiate(OmegaConf.load(args.sd_adapter), unet=unet).eval().to(
            self.vit_sd_device, dtype=self.dtype)
        self.sd_adapter.init_pipe(vae=vae,
                                  scheduler=noise_scheduler,
                                  visual_encoder=self.visual_encoder,
                                  image_transform=self.image_transform,
                                  discrete_model=None,
                                  dtype=self.dtype,
                                  device=self.vit_sd_device)
        print('Init sd adapter pipe done.')
        # Re-pin the visual encoder after init_pipe, which receives it.
        # NOTE(review): kept defensively; unclear whether init_pipe moves it.
        self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)

        # Token ids of the image delimiters; generate() uses them to build ids_cmp_mask.
        self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
        self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]


service = LLMService(args)
def generate(text_list, image_list, max_new_tokens):
    """Run the agent on an interleaved text/image prompt and decode any generated images.

    Args:
        text_list: Prompt string in which every input image is marked by ``IMG_FLAG``.
            Despite the name, it is a single ``str`` that is split on ``IMG_FLAG``.
        image_list: Base64-encoded input images, one per ``IMG_FLAG`` marker, in order.
        max_new_tokens: Upper bound on the number of tokens the agent generates.

    Returns:
        A dict with three keys:
        - ``'text'``: the generated text, with ``EOI_TOKEN`` replaced by ``IMG_FLAG``
          and the tokenizer's EOS token removed;
        - ``'images'``: base64 PNG strings, one per generated image;
        - ``'error_msg'``: a list of error strings (currently always empty).

    Raises:
        ValueError: if the number of ``IMG_FLAG`` markers differs from
            ``len(image_list)``, or an image item is not a string.
    """
    with torch.no_grad():
        text_segments = text_list.split(IMG_FLAG)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(text_segments) != len(image_list) + 1:
            raise ValueError(
                f'Prompt contains {len(text_segments) - 1} {IMG_FLAG} marker(s) '
                f'but {len(image_list)} image(s) were provided.')
        top_p = 0.5

        # Each input image is represented in the prompt as:
        # BOI, num_img_in_tokens placeholders, EOI.
        image_tokens = BOI_TOKEN + ''.join(
            IMG_TOKEN.format(i) for i in range(num_img_in_tokens)) + EOI_TOKEN

        if image_list:
            # Single-resolution input only: one transformed tensor per image.
            image_tensor_list = []
            for image_item in image_list:
                if not isinstance(image_item, str):
                    raise ValueError(
                        f'Expected a base64-encoded image string, got {type(image_item).__name__}.')
                image = decode_image(image_item)
                print('after decode image size:', image.size)
                image_tensor_list.append(service.image_transform(image))
            pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
            image_embeds = service.visual_encoder(pixel_values).to(service.llm_device)
            # One True entry per input image embedding.
            embeds_cmp_mask = torch.ones(len(image_tensor_list), dtype=torch.bool, device=service.llm_device)
        else:
            image_embeds = None
            embeds_cmp_mask = None

        input_text = image_tokens.join(text_segments)
        print('input_text:', input_text)
        input_ids = [service.tokenizer.bos_token_id] + service.tokenizer.encode(input_text,
                                                                                add_special_tokens=False)
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=service.llm_device)

        # Flag the placeholder tokens strictly between each BOI/EOI pair.
        ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
        eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()
        for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
            ids_cmp_mask[boi_idx + 1:eoi_idx] = True

        error_msg = []
        output = service.agent.generate(
            tokenizer=service.tokenizer,
            input_ids=input_ids.unsqueeze(0),
            image_embeds=image_embeds,
            embeds_cmp_mask=embeds_cmp_mask,
            ids_cmp_mask=ids_cmp_mask.unsqueeze(0),
            num_img_gen_tokens=num_img_out_tokens,
            max_new_tokens=max_new_tokens,
            dtype=service.dtype,
            device=service.llm_device,
            top_p=top_p,
        )

        generated_text = output['text'].replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')
        torch.cuda.empty_cache()

        gen_imgs_base64_list = []
        if output['has_img_output']:
            img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)
            for img_idx in range(output['num_gen_imgs']):
                img_feat = img_gen_feat[img_idx:img_idx + 1]
                generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
                # Previously the image was computed and then dropped, so none ever reached the UI.
                # NOTE(review): assumes sd_adapter.generate returns PIL images (encode_image calls .save).
                gen_imgs_base64_list.append(encode_image(generated_image))

        print(input_text + generated_text)
        return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
def http_bot(dialog_state, input_state, max_new_tokens, max_turns,
             request: gr.Request):
    """Produce one assistant turn for the Gradio chat UI.

    Runs ``generate`` on the current prompt, saves any returned images under the
    conversation image directory, appends the assistant reply to ``dialog_state``
    and logs the exchange via ``vote_last_response``.

    Returns:
        ``(dialog_state, input_state, chatbot)`` followed by four button updates
        (upvote, downvote, regenerate, clear).
    """
    print('input_state:', input_state)

    messages = dialog_state.messages
    has_user_input = (
        len(messages) != 0
        and messages[-1]['role'] == dialog_state.roles[0]
        and len(messages[-1]['message']['text'].strip(' ?.;!/')) != 0
    )
    if not has_user_input:
        return (dialog_state, input_state, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    if len(messages) > max_turns * 2:
        overflow_reply = init_input_state()
        overflow_reply['text'] = 'Error: History exceeds maximum rounds, please clear history and restart.'
        dialog_state.messages.append({'role': dialog_state.roles[1], 'message': overflow_reply})
        return ((dialog_state, init_input_state(), dialog_state.to_gradio_chatbot())
                + (disable_btn,) * 3 + (enable_btn,))

    prompt = dialog_state.get_prompt()
    results = generate(prompt['text'], prompt['images'], int(max_new_tokens))
    print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})

    reply_state = init_input_state()
    image_dir = get_conv_image_dir()
    reply_state['text'] = results['text']
    for encoded in results['images']:
        if encoded == '':
            saved_path = ''
        else:
            decoded = decode_image(encoded).convert('RGB')
            saved_path = get_image_name(image=decoded, image_dir=image_dir)
            if not os.path.exists(saved_path):
                decoded.save(saved_path)
        reply_state['images'].append(saved_path)

    dialog_state.messages.append({'role': dialog_state.roles[1], 'message': reply_state})
    vote_last_response(dialog_state, 'common', request)
    fresh_input = init_input_state()
    chatbot = update_error_msg(dialog_state.to_gradio_chatbot(), results['error_msg'])
    return (dialog_state, fresh_input, chatbot) + (enable_btn,) * 4
# IMG_FLAG is already defined (with the same value) alongside the other prompt tokens
# near the top of the file. It is not redefined here, so the two copies cannot drift apart.

# Directory for conversation logs, vote records and saved images.
LOGDIR = 'log'
logger = build_logger("gradio_seed_x", LOGDIR)
# NOTE(review): not referenced by any code visible in this file.
headers = {"User-Agent": "SEED-X Client"}

# Button updates returned by the event handlers.
# no_change_btn is returned on their no-op paths.
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)

# Conversation template; clear_history/load_demo start new dialogs from a .copy() of it.
conv_seed_llama = conv_seed_llama2
def get_conv_log_filename():
    """Return the path of today's conversation log, ``LOGDIR/YYYY-MM-DD-conv.json``.

    ``LOGDIR`` is created if it does not exist, as ``get_conv_image_dir`` does for its
    directory. This lets callers such as ``vote_last_response`` open the file in
    append mode without failing on a missing directory. The date uses local time.
    """
    os.makedirs(LOGDIR, exist_ok=True)
    t = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
def get_conv_image_dir():
    """Return ``LOGDIR/images``, creating it (and any missing parents) first."""
    image_dir = os.path.join(LOGDIR, 'images')
    os.makedirs(image_dir, exist_ok=True)
    return image_dir
def get_image_name(image, image_dir=None):
    """Return a content-addressed file name for *image*.

    The name is the MD5 hex digest of the image's PNG encoding plus ``.png``.
    It is joined onto ``image_dir`` when one is given. Identical images therefore
    map to the same path.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format='PNG')
        digest = hashlib.md5(png_buffer.getvalue()).hexdigest()
    file_name = f'{digest}.png'
    return file_name if image_dir is None else os.path.join(image_dir, file_name)
def resize_image_square(image, target_size=448):
    """Resize *image* to ``target_size`` x ``target_size``; the aspect ratio is not preserved."""
    side = target_size
    return image.resize((side, side))
def resize_image(image, max_size=512):
    """Resize *image* so its longer side equals ``max_size``, keeping the aspect ratio.

    The shorter side is scaled proportionally and rounded down. Integer arithmetic
    avoids float rounding losing a pixel. The result is clamped to at least 1 pixel,
    so very elongated images never produce a zero-sized target, which ``resize``
    would reject. Square images go to ``max_size`` x ``max_size``.
    """
    width, height = image.size
    if width > height:
        new_width = max_size
        new_height = max(1, height * max_size // width)
    else:
        new_height = max_size
        new_width = max(1, width * max_size // height)
    return image.resize((new_width, new_height))
def center_crop_image(image, max_aspect_ratio=1.5):
    """Center-crop *image* so that its long/short side ratio is at most ``max_aspect_ratio``.

    If the ratio is already below the limit, the same image object is returned.
    Otherwise the longer side is trimmed symmetrically to
    ``int(short_side * max_aspect_ratio)`` pixels.
    """
    width, height = image.size
    if max(width, height) / min(width, height) < max_aspect_ratio:
        return image

    if width > height:
        kept = int(height * max_aspect_ratio)
        box = ((width - kept) // 2, 0, (width + kept) // 2, height)
    else:
        kept = int(width * max_aspect_ratio)
        box = (0, (height - kept) // 2, width, (height + kept) // 2)
    return image.crop(box)
def vote_last_response(state, vote_type, request: gr.Request):
    """Append one JSON line to today's conversation log.

    The line records the timestamp, the vote type, the serialized dialog state and
    the client IP.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            state=state.dict(),
            ip=request.client.host,
        )
        print(json.dumps(record), file=log_file)
def upvote_last_response(state, request: gr.Request):
    """Record an upvote for the latest response and disable both vote buttons."""
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", request)
    return disable_btn, disable_btn
def downvote_last_response(state, request: gr.Request):
    """Record a downvote for the latest response and disable both vote buttons."""
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", request)
    return disable_btn, disable_btn
def regenerate(dialog_state, request: gr.Request):
    """Drop the last assistant reply so the chained ``http_bot`` call produces a new one.

    Returns ``(dialog_state, chatbot)`` followed by four disabled-button updates.
    An empty history is left unchanged instead of raising ``IndexError``.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    messages = dialog_state.messages
    if messages and messages[-1]['role'] == dialog_state.roles[1]:
        messages.pop()
    return (dialog_state, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
def clear_history(request: gr.Request):
    """Reset the UI to a fresh conversation and empty input; all four action buttons are disabled."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_dialog = conv_seed_llama.copy()
    return (fresh_dialog, init_input_state(), fresh_dialog.to_gradio_chatbot()) + tuple([disable_btn] * 4)
def init_input_state():
    """Return a new, empty per-turn input state: no image paths and empty text."""
    return dict(images=[], text='')
def add_text(dialog_state, input_state, text, request: gr.Request):
    """Append *text* to the pending user input and attach it to the current user message.

    Empty or ``None`` text changes nothing and leaves the buttons untouched.
    Otherwise the textbox is cleared and all four action buttons are disabled.
    """
    logger.info(f"add_text. ip: {request.client.host}.")
    if text is None or text == '':
        return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    input_state['text'] += text
    messages = dialog_state.messages
    user_role = dialog_state.roles[0]
    last_is_user = len(messages) > 0 and messages[-1]['role'] == user_role
    if last_is_user:
        messages[-1]['message'] = input_state
    else:
        messages.append({'role': user_role, 'message': input_state})

    print('add_text: ', dialog_state.to_gradio_chatbot())
    return (dialog_state, input_state, "", dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
def is_blank(image):
    """Return True if every pixel of *image* has the same color.

    For multi-channel input (e.g. an RGB PIL image, which ``np.asarray`` turns into
    ``(H, W, C)``), whole pixels are compared across the last axis. Previously the
    values were flattened, which counted distinct channel values. That made a solid
    red image look non-blank. Single-channel input compares scalar values.
    Empty input returns False.
    """
    image_array = np.asarray(image)
    if image_array.ndim >= 3:
        pixels = image_array.reshape(-1, image_array.shape[-1])
        unique_colors = np.unique(pixels, axis=0)
    else:
        unique_colors = np.unique(image_array)
    print('unique_colors', len(unique_colors))
    return len(unique_colors) == 1
def add_image(dialog_state, input_state, image, request: gr.Request):
    """Save an uploaded image and append it to the pending user input.

    The image is converted to RGB and center-cropped to an aspect ratio of at most 10.
    It is saved under a content-addressed name in the conversation image directory,
    and its path plus an ``IMG_FLAG`` marker are added to *input_state*. A ``None``
    image changes nothing.
    """
    logger.info(f"add_image. ip: {request.client.host}.")
    if image is None:
        return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (no_change_btn,) * 4

    rgb_image = image.convert('RGB')
    print('image size:', rgb_image.size)
    cropped = center_crop_image(rgb_image, max_aspect_ratio=10)
    saved_path = get_image_name(image=cropped, image_dir=get_conv_image_dir())
    if not os.path.exists(saved_path):
        cropped.save(saved_path)

    input_state['images'].append(saved_path)
    input_state['text'] += IMG_FLAG

    messages = dialog_state.messages
    user_role = dialog_state.roles[0]
    if len(messages) > 0 and messages[-1]['role'] == user_role:
        messages[-1]['message'] = input_state
    else:
        messages.append({'role': user_role, 'message': input_state})

    print('add_image:', dialog_state)
    return (dialog_state, input_state, None, dialog_state.to_gradio_chatbot()) + (disable_btn,) * 4
def update_error_msg(chatbot, error_msg):
    """Append an error notice listing *error_msg* to the last chatbot reply.

    *chatbot* is modified in place and also returned. It is left unchanged when
    *error_msg* is empty.
    """
    if len(error_msg) == 0:
        return chatbot
    notice = ('\n-------------\nSome errors occurred during response, please clear history and restart.\n'
              + '\n'.join(error_msg))
    chatbot[-1][-1] += notice
    return chatbot
def load_demo(request: gr.Request):
    """Initialize per-session state on page load: a fresh conversation and an empty input."""
    logger.info(f"load_demo. ip: {request.client.host}")
    return conv_seed_llama.copy(), init_input_state()
# Markdown header shown at the top of the demo page. Paragraphs are separated by
# blank lines so Markdown renders them separately instead of merging them.
title = ("""
# SEED-Story

[[Paper]](https://arxiv.org/abs/2407.08683) [[Code]](https://github.com/TencentARC/SEED-Story)

Demo of SEED-Story-George, a multimodal story generation model trained on the Curious George subset of StoryStream.

SEED-Story is an MLLM capable of generating multimodal long stories consisting of rich and coherent narrative texts, along with images that are consistent in characters and style.

## Tips:

* Check out the conversation examples (at the bottom) for inspiration.
* You can adjust "Max History Rounds" to try a conversation with up to **three rounds (limited by GPU memory)**. For more turns, you can download our checkpoints from GitHub and deploy them locally for inference.
* Our demo supports a mix of images and texts as input. You can freely upload an image or enter text, and then click "Add Image" or "Add Text". You can repeat this step multiple times, then click "Submit" to run model inference.
* SEED-Story was trained with English-only data. It may handle other languages thanks to the inherent capabilities of LLaMA, but the results might not be stable.
""")
# Custom CSS passed to gr.Blocks. The ::before rule draws a dotted grey box and the
# ::after rule adds an empty overlay.
# NOTE(review): pseudo-elements on <img> generally render only when the image fails
# to load, so this presumably styles broken-image placeholders.
css = """
img {
    font-family: 'Helvetica';
    font-weight: 300;
    line-height: 2;
    text-align: center;
    width: auto;
    height: auto;
    display: block;
    position: relative;
}
img:before {
    content: " ";
    display: block;
    position: absolute;
    top: -10px;
    left: 0;
    height: calc(100% + 10px);
    width: 100%;
    background-color: rgb(230, 230, 230);
    border: 2px dotted rgb(200, 200, 200);
    border-radius: 5px;
}
img:after {
    content: " ";
    display: block;
    font-size: 16px;
    font-style: normal;
    font-family: FontAwesome;
    color: rgb(100, 100, 100);
    position: absolute;
    top: 5px;
    left: 0;
    width: 100%;
    text-align: center;
}
"""
if __name__ == '__main__':
    # Example prompts shown under the chat. Mixed examples are [image URL, text].
    examples_mix = [
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/bank.png?raw=true',
         'Can I connect with an advisor on Sunday?'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/ground.png?raw=true',
         'Is there anything in the image that can protect me from catching the flu virus when I go out? Show me the location.'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/arrow.jpg?raw=true',
         'What is the object pointed by the red arrow?'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/shanghai.png?raw=true',
         'Where was this image taken? Explain your answer.'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/GPT4.png?raw=true',
         'How long does it take to make GPT-4 safer?'],
        ['https://github.com/AILab-CVC/SEED-X/blob/main/demos/twitter.png?raw=true',
         'Please provide a comprehensive description of this image.'],
    ]
    examples_text = [
        ['I want to build a two story cabin in the woods, with many commanding windows. Can you show me a picture?'],
        ['Use your imagination to design a concept image for Artificial General Intelligence (AGI). Show me an image.'],
        ['Can you design an illustration for “The Three-Body Problem” to depict a scene from the novel? Show me a picture.'],
        ['My four year old son loves toy trains. Can you design a fancy birthday cake for him? Please generate a picture.'],
        ['Generate an image of a portrait of young nordic girl, age 25, freckled skin, neck tattoo, blue eyes 35mm lens, photography, ultra details.'],
        ['Generate an impressionist painting of an astronaut in a jungle.'],
    ]

    with gr.Blocks(css=css) as demo:
        gr.Markdown(title)
        dialog_state = gr.State()
        input_state = gr.State()

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input_image')
                with gr.Row():
                    text = gr.Textbox(lines=5,
                                      show_label=False,
                                      label='input_text',
                                      elem_id='textbox',
                                      placeholder="Enter text and image, and press submit,", container=False)
                with gr.Row():
                    add_image_btn = gr.Button("Add Image")
                    add_text_btn = gr.Button("Add Text")
                    submit_btn = gr.Button("Submit")
                with gr.Row():
                    max_new_tokens = gr.Slider(minimum=64,
                                               maximum=1024,
                                               value=768,
                                               step=64,
                                               interactive=True,
                                               label="Max Output Tokens")
                    max_turns = gr.Slider(minimum=1, maximum=3, value=3, step=1, interactive=True,
                                          label="Max History Rounds")
                    # NOTE(review): http_bot/generate take no such options, so these radios
                    # are not wired into inference. Passing them as inputs crashed every request.
                    force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
                    force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
                    force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
                with gr.Row():
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        # Column scale must be an integer; 7:3 keeps the original 0.7:0.3 proportions.
        with gr.Row():
            with gr.Column(scale=7):
                gr.Examples(examples=examples_mix, label='Input examples', inputs=[image, text], cache_examples=False)
            with gr.Column(scale=3):
                gr.Examples(examples=examples_text, label='Input examples', inputs=[text], cache_examples=False)

        # Register listeners.
        btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn]
        # http_bot(dialog_state, input_state, max_new_tokens, max_turns, request): exactly
        # four components are passed. Gradio supplies the gr.Request-annotated argument.
        bot_inputs = [dialog_state, input_state, max_new_tokens, max_turns]
        bot_outputs = [dialog_state, input_state, chatbot] + btn_list

        upvote_btn.click(upvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
        downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
        regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
            http_bot, bot_inputs, bot_outputs)
        add_image_btn.click(add_image, [dialog_state, input_state, image],
                            [dialog_state, input_state, image, chatbot] + btn_list)
        add_text_btn.click(add_text, [dialog_state, input_state, text],
                           [dialog_state, input_state, text, chatbot] + btn_list)
        submit_btn.click(
            add_image, [dialog_state, input_state, image], [dialog_state, input_state, image, chatbot] + btn_list).then(
            add_text, [dialog_state, input_state, text],
            [dialog_state, input_state, text, chatbot] + btn_list).then(
            http_bot, bot_inputs, bot_outputs)
        clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
        demo.load(load_demo, None, [dialog_state, input_state])

    demo.launch(debug=True)