import argparse

import gradio as gr
import torch
from lavis.models import load_model_and_preprocess

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-name", default="blip2_vicuna_instruct")
    parser.add_argument("--model-type", default="vicuna7b")
    args = parser.parse_args()

    # Gradio input components for the image and generation hyperparameters.
    image_input = gr.Image(type="pil")
    min_len = gr.Slider(
        minimum=1,
        maximum=50,
        value=1,
        step=1,
        interactive=True,
        label="Min Length",
    )
    max_len = gr.Slider(
        minimum=10,
        maximum=500,
        value=250,
        step=5,
        interactive=True,
        label="Max Length",
    )
    sampling = gr.Radio(
        choices=["Beam search", "Nucleus sampling"],
        value="Beam search",
        label="Text Decoding Method",
        interactive=True,
    )
    top_p = gr.Slider(
        minimum=0.5,
        maximum=1.0,
        value=0.9,
        step=0.1,
        interactive=True,
        label="Top p",
    )
    beam_size = gr.Slider(
        minimum=1,
        maximum=10,
        value=5,
        step=1,
        interactive=True,
        label="Beam Size",
    )
    len_penalty = gr.Slider(
        minimum=-1,
        maximum=2,
        value=1,
        step=0.2,
        interactive=True,
        label="Length Penalty",
    )
    repetition_penalty = gr.Slider(
        minimum=-1,
        maximum=3,
        value=1,
        step=0.2,
        interactive=True,
        label="Repetition Penalty",
    )
    # prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print('Loading model...')
    model, vis_processors, _ = load_model_and_preprocess(
        name=args.model_name,
        model_type=args.model_type,
        is_eval=True,
        device=device,
    )
    print('Loading model done!')

    # def inference(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
    def inference(image, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
        use_nucleus_sampling = decoding_method == "Nucleus sampling"

        # Preprocess the PIL image into a batched tensor on the target device.
        image = vis_processors["eval"](image).unsqueeze(0).to(device)

        samples = {
            "image": image,
            # "prompt": prompt,
            "prompt": "Describe the image in detail and where the violent objects are positioned in the image (center, left, right, top, bottom).",
        }

        output = model.generate(
            samples,
            length_penalty=float(len_penalty),
            repetition_penalty=float(repetition_penalty),
            num_beams=beam_size,
            max_length=max_len,
            min_length=min_len,
            top_p=top_p,
            use_nucleus_sampling=use_nucleus_sampling,
        )
        return output[0]

    gr.Interface(
        fn=inference,
        # inputs=[image_input, prompt_textbox, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling],
        inputs=[image_input, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, sampling],
        outputs="text",
        allow_flagging="never",
    ).launch()
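
    # Example invocation (a sketch; the script filename "demo.py" is an assumption, not given above):
    #   python demo.py --model-name blip2_vicuna_instruct --model-type vicuna7b
    # After the model loads, Gradio serves the interface locally (by default at http://127.0.0.1:7860).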