# NOTE(review): the three lines that were here ("Spaces:" / "Paused" / "Paused")
# were Hugging Face Spaces page residue from scraping, not program text.
import os

# NOTE(review): installing packages at import time via os.system is fragile —
# no exit-code checking, shell-dependent, and it runs on every import.
# Consider moving these to the deployment's requirements instead.
os.system('pip install git+https://github.com/IDEA-Research/GroundingDINO.git')
os.system('pip install git+https://github.com/facebookresearch/segment-anything.git')

# Star import: brings the foundation-model classes (Text2Box, Segmenting, ...)
# into globals(), which ConversationBot.__init__ looks up by name. Presumably it
# also re-exports the prompt constants (VISUAL_CHATGPT_PREFIX, ...) and helpers
# (uuid, PIL's Image, np) used below — verify against visual_foundation_models.
from visual_foundation_models import *
from langchain.agents.initialize import initialize_agent
from langchain.agents.tools import Tool
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.llms.openai import OpenAI
import re
import gradio as gr
import inspect
| def cut_dialogue_history(history_memory, keep_last_n_words=400): | |
| if history_memory is None or len(history_memory) == 0: | |
| return history_memory | |
| tokens = history_memory.split() | |
| n_tokens = len(tokens) | |
| print(f"history_memory:{history_memory}, n_tokens: {n_tokens}") | |
| if n_tokens < keep_last_n_words: | |
| return history_memory | |
| paragraphs = history_memory.split('\n') | |
| last_n_tokens = n_tokens | |
| while last_n_tokens >= keep_last_n_words: | |
| last_n_tokens -= len(paragraphs[0].split(' ')) | |
| paragraphs = paragraphs[1:] | |
| return '\n' + '\n'.join(paragraphs) | |
class ConversationBot:
    """Chat agent that exposes visual foundation models as LangChain tools.

    The bot keeps a conversation buffer, routes text turns through a
    conversational-react agent, and handles image uploads by captioning them
    and injecting the description into the chat history.
    """

    def __init__(self, load_dict):
        """Instantiate the requested foundation models and build the tool list.

        Args:
            load_dict: Mapping of model class name -> device string, e.g.
                {'VisualQuestionAnswering': 'cuda:0', 'ImageCaptioning': 'cuda:1'}.

        Raises:
            ValueError: If 'ImageCaptioning' is missing — run_image depends on it.
        """
        print(f"Initializing VisualChatGPT, load_dict={load_dict}")
        if 'ImageCaptioning' not in load_dict:
            raise ValueError("You have to load ImageCaptioning as a basic function for VisualChatGPT")

        self.models = {}
        # Load basic foundation models: each class is looked up by name in
        # module globals (populated by `from visual_foundation_models import *`).
        for class_name, device in load_dict.items():
            self.models[class_name] = globals()[class_name](device=device)

        # Load template foundation models: classes flagged with
        # `template_model = True` own no weights of their own — they are
        # composed from already-loaded models whose class names match the
        # template's __init__ parameters.
        for class_name, module in globals().items():
            if getattr(module, 'template_model', False):
                template_required_names = {
                    k for k in inspect.signature(module.__init__).parameters.keys()
                    if k != 'self'}
                loaded_names = {type(e).__name__ for e in self.models.values()}
                if template_required_names.issubset(loaded_names):
                    self.models[class_name] = globals()[class_name](
                        **{name: self.models[name] for name in template_required_names})

        # Every method whose name starts with 'inference' becomes a LangChain
        # tool; the methods carry .name/.description prompt metadata (attached
        # by a decorator in visual_foundation_models — presumably; verify).
        self.tools = []
        for instance in self.models.values():
            for e in dir(instance):
                if e.startswith('inference'):
                    func = getattr(instance, e)
                    self.tools.append(Tool(name=func.name, description=func.description, func=func))
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')

    def run_text(self, text, state):
        """Run one text turn through the agent and append (text, response) to state.

        Returns the updated state twice, matching the Gradio outputs
        [chatbot, state].
        """
        # Keep the agent's memory bounded so the prompt stays within limits.
        self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
        res = self.agent({"input": text.strip()})
        res['output'] = res['output'].replace("\\", "/")
        # Render every generated image path as a markdown image followed by the
        # path itself. Bug fix: the replacement template had been corrupted to
        # '})*...*' (markdown-eaten), which emitted broken text instead of the
        # inline image link.
        response = re.sub(r'(image/\S*png)',
                          lambda m: f'![](file={m.group(0)})*{m.group(0)}*',
                          res['output'])
        state = state + [(text, response)]
        print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return state, state

    def run_image(self, image, state, txt, lang):
        """Ingest an uploaded image: resize, save, caption, and prime the agent memory.

        Args:
            image: Gradio upload object (has a .name file path).
            state: Current chat state (list of (user, bot) tuples).
            txt: Current textbox content; the saved filename is appended to it.
            lang: 'Chinese' or anything else (treated as English).

        Returns:
            (state, state, updated textbox text) for outputs [chatbot, state, txt].
        """
        # NOTE(review): uuid, Image (PIL) and np are not imported explicitly in
        # this file — presumably re-exported by visual_foundation_models; verify.
        os.makedirs('image', exist_ok=True)  # robustness: ensure target dir exists
        image_filename = os.path.join('image', f"{str(uuid.uuid4())[:8]}.png")
        print("======>Auto Resize Image...")
        img = Image.open(image.name)
        width, height = img.size
        # Fit within 512x512 preserving aspect ratio, then snap each side to a
        # multiple of 64 (a diffusion-model-friendly size).
        ratio = min(512 / width, 512 / height)
        width_new, height_new = (round(width * ratio), round(height * ratio))
        width_new = int(np.round(width_new / 64.0)) * 64
        height_new = int(np.round(height_new / 64.0)) * 64
        img = img.resize((width_new, height_new))
        img = img.convert('RGB')
        img.save(image_filename, "PNG")
        print(f"Resize image form {width}x{height} to {width_new}x{height_new}")
        description = self.models['ImageCaptioning'].inference(image_filename)
        if lang == 'Chinese':
            Human_prompt = f'\nHuman: 提供一张名为 {image_filename}的图片。它的描述是: {description}。 这些信息帮助你理解这个图像,但是你应该使用工具来完成下面的任务,而不是直接从我的描述中想象。 如果你明白了, 说 \"收到\". \n'
            AI_prompt = "收到。 "
        else:
            Human_prompt = f'\nHuman: provide a figure named {image_filename}. The description is: {description}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
            AI_prompt = "Received. "
        # Seed the conversation with the caption so later tool calls can refer
        # to the file by name.
        self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
        # Bug fix: restore the markdown image link here as well (it had been
        # corrupted to just '*{image_filename}*'), so the upload previews in chat.
        state = state + [(f"![](file={image_filename})*{image_filename}*", AI_prompt)]
        print(f"\nProcessed run_image, Input image: {image_filename}\nCurrent state: {state}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return state, state, f'{txt} {image_filename} '

    def init_agent(self, openai_api_key, lang):
        """(Re)build the LangChain agent for the given API key and UI language.

        Returns a gr.update making the hidden input row visible.
        """
        self.memory.clear()
        # Prompt constants are presumably exported by visual_foundation_models
        # via the star import — verify.
        if lang == 'English':
            PREFIX, FORMAT_INSTRUCTIONS, SUFFIX = VISUAL_CHATGPT_PREFIX, VISUAL_CHATGPT_FORMAT_INSTRUCTIONS, VISUAL_CHATGPT_SUFFIX
            place = "Enter text and press enter, or upload an image"
            label_clear = "Clear"
        else:
            PREFIX, FORMAT_INSTRUCTIONS, SUFFIX = VISUAL_CHATGPT_PREFIX_CN, VISUAL_CHATGPT_FORMAT_INSTRUCTIONS_CN, VISUAL_CHATGPT_SUFFIX_CN
            place = "输入文字并回车,或者上传图片"
            label_clear = "清除"
        # NOTE(review): `place` / `label_clear` are computed but never used in
        # this version (upstream also updates the textbox/button labels).
        self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="conversational-react-description",
            verbose=True,
            memory=self.memory,
            return_intermediate_steps=True,
            agent_kwargs={'prefix': PREFIX, 'format_instructions': FORMAT_INSTRUCTIONS, 'suffix': SUFFIX}, )
        return gr.update(visible=True)
# Device placement for each foundation model: heavy generative models live on
# the GPU, while the cheap edge/depth detectors run on CPU.
_MODEL_DEVICES = {
    'Text2Box': 'cuda:0',
    'Segmenting': 'cuda:0',
    'Inpainting': 'cuda:0',
    'Text2Image': 'cuda:0',
    'ImageCaptioning': 'cuda:0',
    'VisualQuestionAnswering': 'cuda:0',
    'Image2Canny': 'cpu',
    'CannyText2Image': 'cuda:0',
    'InstructPix2Pix': 'cuda:0',
    'Image2Depth': 'cpu',
    'DepthText2Image': 'cuda:0',
}

bot = ConversationBot(_MODEL_DEVICES)
# Gradio UI: a chat panel plus an input row that is hidden until the user
# submits a valid OpenAI API key (bot.init_agent reveals it).
# NOTE(review): `.style(container=False)`, fractional `scale=` values, and
# `queue(concurrency_count=...)` are APIs from older Gradio (3.x) — confirm
# the pinned gradio version before upgrading.
with gr.Blocks(css="#chatbot {overflow:auto; height:500px;}") as demo:
    gr.Markdown("<h3><center>KPMG MULTIMODALGPT</center></h3>")
    # Empty markdown placeholder (renders nothing).
    gr.Markdown(
        """
        """
    )
    with gr.Row():
        lang = gr.Radio(choices=['Chinese', 'English'], value='English', label='Language')
        openai_api_key_textbox = gr.Textbox(
            placeholder="Paste your OpenAI API key here to start Visual ChatGPT(sk-...) and press Enter ↵️",
            show_label=False,
            lines=1,
            type="password",
        )
    chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT")
    # Chat state: list of (user, bot) message tuples shared across callbacks.
    state = gr.State([])
    # Hidden until init_agent returns gr.update(visible=True).
    with gr.Row(visible=False) as input_raws:
        with gr.Column(scale=0.7):
            txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(container=False)
        with gr.Column(scale=0.10, min_width=0):
            run = gr.Button("🏃♂️Run")
        with gr.Column(scale=0.10, min_width=0):
            clear = gr.Button("🔄Clear️")
        with gr.Column(scale=0.10, min_width=0):
            btn = gr.UploadButton("🖼️Upload", file_types=["image"])
    gr.Examples(
        examples=["Generate a figure of a cat running in the garden",
                  "Replace the cat with a dog",
                  "Remove the dog in this image",
                  "Can you detect the canny edge of this image?",
                  "Can you use this canny image to generate an oil painting of a dog",
                  "Make it like water-color painting",
                  "What is the background color",
                  "Describe this image",
                  "please detect the depth of this image",
                  "Can you use this depth image to generate a cute dog",
                  ],
        inputs=txt
    )
    gr.HTML(''' ''')

    # Event wiring.
    # Submitting the API key builds the agent and reveals the input row.
    openai_api_key_textbox.submit(bot.init_agent, [openai_api_key_textbox, lang], [input_raws])
    # Enter in the textbox or the Run button both send the turn, then clear txt.
    txt.submit(bot.run_text, [txt, state], [chatbot, state])
    txt.submit(lambda: "", None, txt)
    run.click(bot.run_text, [txt, state], [chatbot, state])
    run.click(lambda: "", None, txt)
    btn.upload(bot.run_image, [btn, state, txt, lang], [chatbot, state, txt])
    # Clear resets the LangChain memory, the chat panel, and the state list.
    clear.click(bot.memory.clear)
    clear.click(lambda: [], None, chatbot)
    clear.click(lambda: [], None, state)

demo.queue(concurrency_count=10).launch(server_name="0.0.0.0", server_port=7860)