Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from lavis.models import load_model_and_preprocess | |
| import torch | |
| import argparse | |
# Prompt sent to the model together with every uploaded image.
# It can be overridden with ``--prompt`` on the command line.
DEFAULT_PROMPT = (
    "Describe the image in detail and where are the violence objects position "
    "in the image (center, left, right, top, bottom)."
)


def build_inference_fn(model, vis_processors, device, prompt=DEFAULT_PROMPT):
    """Return the Gradio callback that generates text for one image.

    Args:
        model: object exposing ``generate(samples, **decoding_kwargs)``; the
            first element of its return value is shown to the user.
        vis_processors: mapping whose ``"eval"`` entry turns the uploaded
            image into a tensor-like object supporting ``unsqueeze``/``to``.
        device: target passed to ``.to(...)`` for the processed image.
        prompt: text prompt sent alongside every image.

    Returns:
        ``inference(image, min_len, max_len, beam_size, len_penalty,
        repetition_penalty, top_p, decoding_method)``. Its positional
        parameters match, one-to-one and in order, the Interface inputs
        built in ``main``.
    """

    def inference(image, min_len, max_len, beam_size, len_penalty,
                  repetition_penalty, top_p, decoding_method):
        # Submitting without uploading an image yields None; report it
        # clearly instead of failing inside the preprocessor.
        if image is None:
            raise gr.Error("Please upload an image.")

        # Add a leading batch dimension of 1 and move to the model's device.
        pixels = vis_processors["eval"](image).unsqueeze(0).to(device)
        samples = {"image": pixels, "prompt": prompt}

        # Slider values may arrive as floats; lengths and beam counts must
        # be integers, penalties and top_p floats.
        output = model.generate(
            samples,
            length_penalty=float(len_penalty),
            repetition_penalty=float(repetition_penalty),
            num_beams=int(round(beam_size)),
            max_length=int(round(max_len)),
            min_length=int(round(min_len)),
            top_p=float(top_p),
            use_nucleus_sampling=decoding_method == "Nucleus sampling",
        )
        return output[0]

    return inference


def main():
    """Parse CLI options, load the LAVIS model and launch the Gradio demo."""
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-name", default="blip2_vicuna_instruct")
    parser.add_argument("--model-type", default="vicuna7b")
    parser.add_argument(
        "--prompt",
        default=DEFAULT_PROMPT,
        help="Prompt sent to the model together with every image.",
    )
    args = parser.parse_args()

    image_input = gr.Image(type="pil")
    min_len = gr.Slider(
        minimum=1,
        maximum=50,
        value=1,
        step=1,
        interactive=True,
        label="Min Length",
    )
    max_len = gr.Slider(
        minimum=10,
        maximum=500,
        value=250,
        step=5,
        interactive=True,
        label="Max Length",
    )
    sampling = gr.Radio(
        choices=["Beam search", "Nucleus sampling"],
        value="Beam search",
        label="Text Decoding Method",
        interactive=True,
    )
    top_p = gr.Slider(
        minimum=0.5,
        maximum=1.0,
        value=0.9,
        step=0.1,
        interactive=True,
        label="Top p",
    )
    beam_size = gr.Slider(
        minimum=1,
        maximum=10,
        value=5,
        step=1,
        interactive=True,
        label="Beam Size",
    )
    len_penalty = gr.Slider(
        minimum=-1,
        maximum=2,
        value=1,
        step=0.2,
        interactive=True,
        label="Length Penalty",
    )
    repetition_penalty = gr.Slider(
        # Hugging Face generation rejects repetition penalties <= 0, so the
        # range starts above zero (the old minimum of -1 crashed generation).
        minimum=0.2,
        maximum=3,
        value=1,
        step=0.2,
        interactive=True,
        label="Repetition Penalty",
    )

    device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
    print('Loading model...')
    model, vis_processors, _ = load_model_and_preprocess(
        name=args.model_name,
        model_type=args.model_type,
        is_eval=True,
        device=device,
    )
    print('Loading model done!')

    gr.Interface(
        fn=build_inference_fn(model, vis_processors, device, prompt=args.prompt),
        # Order must match inference()'s positional parameters.
        inputs=[image_input, min_len, max_len, beam_size, len_penalty,
                repetition_penalty, top_p, sampling],
        outputs="text",
        allow_flagging="never",
    ).launch()


if __name__ == '__main__':
    main()