Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import pickle | |
| import base64 | |
| import requests | |
| import argparse | |
| import numpy as np | |
| import gradio as gr | |
| from functools import partial | |
| from PIL import Image | |
| SERVER_URL = os.getenv('SERVER_URL') | |
def get_images(state):
    """Download any chat-referenced images that are missing locally.

    Scans every message in ``state`` (a nested list of message strings, as
    kept by the hidden ``gr.Chatbot`` state) for ``image/<name>.png`` paths
    and fetches each missing file from the remote worker at ``SERVER_URL``,
    saving it under the same relative path.

    Args:
        state: chat history; a list of message lists whose entries are
            strings that may embed ``image/....png`` paths.
    """
    # Flatten the whole history into one string so a single regex pass
    # finds every referenced image path ('\n'.join replaces the original
    # quadratic `history += ...` loop).
    history = '\n'.join(message for pair in state for message in pair)
    # Raw string avoids the invalid '\.' escape warning; the set() skips
    # duplicate references to the same file so each is fetched at most once.
    for image_path in set(re.findall(r'image/[0-9,a-z]+\.png', history)):
        if os.path.exists(image_path):
            continue  # already cached locally
        payload = {'method': 'get_image', 'args': [image_path], 'kwargs': {}}
        payload = base64.b64encode(pickle.dumps(payload)).decode('utf-8')
        response = requests.post(SERVER_URL, json=payload)
        # SECURITY: pickle.loads on data received from the network is code
        # execution for whoever controls SERVER_URL -- acceptable only if
        # that endpoint is fully trusted.
        image = pickle.loads(base64.b64decode(response.json().encode('utf-8')))
        image.save(image_path)
def bot_request(method, *args, **kwargs):
    """Forward a remote procedure call to the backend at ``SERVER_URL``.

    The call spec ``{method, args, kwargs}`` is pickled, base64-encoded and
    POSTed as JSON; the reply body is decoded the same way.  When the reply
    is non-None, its first element is the updated chat state and any images
    it references are pulled down locally before returning.

    NOTE(review): pickle over the network is only safe if SERVER_URL is a
    fully trusted endpoint -- a malicious server gains code execution here.
    """
    request_blob = pickle.dumps({'method': method, 'args': args, 'kwargs': kwargs})
    encoded = base64.b64encode(request_blob).decode('utf-8')
    reply = requests.post(SERVER_URL, json=encoded)
    result = pickle.loads(base64.b64decode(reply.json().encode('utf-8')))
    if result is not None:
        # Make sure every image the new chat state mentions exists on disk.
        get_images(result[0])
    return result
| def run_image(image, *args, **kwargs): | |
| if image is not None: | |
| width, height = image.size | |
| ratio = min(512 / width, 512 / height) | |
| width_new, height_new = (round(width * ratio), round(height * ratio)) | |
| width_new = int(np.round(width_new / 64.0)) * 64 | |
| height_new = int(np.round(height_new / 64.0)) * 64 | |
| image = image.resize((width_new, height_new)) | |
| image = image.convert('RGB') | |
| return bot_request('run_image', image, *args, **kwargs) | |
def predict_example(temperature, top_p, max_new_token, keep_last_n_paragraphs, image, text):
    """Run one cached example end to end: push the image, then the prompt.

    Used by ``gr.Examples`` with the generation settings pre-bound via
    ``functools.partial``.  Returns ``(chatbot, state, text, image, buffer)``
    with the image slot cleared (None) after processing.
    """
    chat_state, scratch = [], ''
    chatbot, chat_state, text, scratch = run_image(image, chat_state, text, scratch)
    chatbot, chat_state, text, scratch = bot_request(
        'run_text', text, chat_state, temperature, top_p,
        max_new_token, keep_last_n_paragraphs, scratch)
    return chatbot, chat_state, text, None, scratch
if __name__ == '__main__':
    # CLI flags provide the generation defaults surfaced in the UI sliders.
    parser = argparse.ArgumentParser()
    parser.add_argument('--temperature', type=float, default=0.0, help='temperature for the llm model')
    parser.add_argument('--max_new_tokens', type=int, default=256, help='max number of new tokens to generate')
    parser.add_argument('--top_p', type=float, default=1.0, help='top_p for the llm model')
    # NOTE(review): --top_k is parsed but never used anywhere below.
    parser.add_argument('--top_k', type=int, default=40, help='top_k for the llm model')
    parser.add_argument('--keep_last_n_paragraphs', type=int, default=0, help='keep last n paragraphs in the memory')
    args = parser.parse_args()
    # (image path, prompt) pairs fed to the gr.Examples gallery.
    examples = [
        ['images/example-1.jpg', 'What is unusual about this image?'],
        ['images/example-2.jpg', 'Make the image look like a cartoon.'],
        ['images/example-3.jpg', 'Segment the tie in the image.'],
        ['images/example-4.jpg', 'Generate a man watching a sea based on the pose of the woman.'],
        ['images/example-5.jpg', 'Replace the dog with a monkey.'],
    ]
    # Local cache directory where get_images() saves files fetched from the backend.
    if not os.path.exists('image'):
        os.makedirs('image')
    with gr.Blocks() as demo:
        # Hidden components that carry conversation state between callbacks:
        # `state` is the chat history, `buffer` is an opaque server-side scratch string.
        state = gr.Chatbot([], visible=False)
        buffer = gr.Textbox('', visible=False)
        with gr.Row():
            # Left column: inputs and generation controls.
            with gr.Column(scale=0.3):
                with gr.Row():
                    image = gr.Image(type='pil', label='input image')
                with gr.Row():
                    # NOTE(review): .style() is the legacy Gradio 3.x API --
                    # removed in Gradio 4; confirm the pinned gradio version.
                    txt = gr.Textbox(lines=7, show_label=False, elem_id='textbox',
                                     placeholder='Enter text and press submit, or upload an image').style(container=False)
                with gr.Row():
                    submit = gr.Button('Submit')
                with gr.Row():
                    clear = gr.Button('Clear')
                with gr.Row():
                    # Single-choice backend selector (display only; value is never read below).
                    llm_name = gr.Radio(
                        ["Vicuna-13B"],
                        label="LLM Backend",
                        value="Vicuna-13B",
                        interactive=True)
                    keep_last_n_paragraphs = gr.Slider(
                        minimum=0,
                        maximum=3,
                        value=args.keep_last_n_paragraphs,
                        step=1,
                        interactive=True,
                        label='Remember Last N Paragraphs')
                    max_new_token = gr.Slider(
                        minimum=64,
                        maximum=512,
                        value=args.max_new_tokens,
                        step=1,
                        interactive=True,
                        label='Max New Tokens')
                    # Temperature and top-p sliders exist but are hidden;
                    # they still feed their CLI-default values into the callbacks.
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.temperature,
                        step=0.1,
                        interactive=True,
                        visible=False,
                        label='Temperature')
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=args.top_p,
                        step=0.1,
                        interactive=True,
                        visible=False,
                        label='Top P')
            # Right column: the visible chat transcript.
            with gr.Column(scale=0.7):
                chatbot = gr.Chatbot(elem_id='chatbot', label='🦙 GPT4Tools').style(height=690)
        # Uploading a new image clears the text box.
        image.upload(lambda: '', None, txt)
        # Submit pipeline: send the image first, then the text prompt,
        # then clear the image widget.
        submit.click(run_image,
                     [image, state, txt, buffer],
                     [chatbot, state, txt, buffer]).then(
            partial(bot_request, 'run_text'),
            [txt, state, temperature, top_p, max_new_token, keep_last_n_paragraphs, buffer],
            [chatbot, state, txt, buffer]).then(
            lambda: None, None, image)
        # Clear resets both the server-side session and the local widgets.
        clear.click(partial(bot_request, 'clear'))
        clear.click(lambda: [[], [], '', ''], None, [chatbot, state, txt, buffer])
        with gr.Row():
            # Pre-computed examples; generation settings are bound from the CLI args.
            gr.Examples(
                examples=examples,
                fn=partial(predict_example, args.temperature, args.top_p,
                           args.max_new_tokens, args.keep_last_n_paragraphs),
                inputs=[image, txt],
                outputs=[chatbot, state, txt, image, buffer],
                cache_examples=True,
            )
    # Queue allows up to 6 concurrent requests before launching the app.
    demo.queue(concurrency_count=6)
    demo.launch()