Spaces:
Runtime error
| import gradio as gr | |
| import gradio.components as grc | |
| import torch | |
| from lavis.models import load_model_and_preprocess | |
| from lavis.processors import load_processor | |
# Select the compute device: CUDA when it is available, otherwise the CPU.
# Both branches produce a torch.device, so `device` has one consistent type
# regardless of the hardware the app runs on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the BLIP-2 image-text matching model ("pretrain" variant) onto the
# selected device with is_eval=True, together with the visual and text
# processor mappings that predict() indexes with the "eval" key.
model, vis_processors, text_processors = load_model_and_preprocess(
    "blip2_image_text_matching",
    "pretrain",
    is_eval=True,
    device=device,
)
def predict(raw_image, caption):
    """Score how well ``caption`` matches ``raw_image`` with the BLIP-2 model.

    Args:
        raw_image: PIL image from the ``Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        caption: Text from the ``Textbox`` input.

    Returns:
        A pair of strings ``(itm_score, itc_score)``:
        - itm_score: column 1 of the softmax over the ITM head's output,
          formatted with 3 decimals.
        - itc_score: the ITC head's output, formatted with 4 decimals.

    Raises:
        gr.Error: if no image was provided. Gradio then shows a readable
            message instead of an AttributeError from ``None.convert``.
    """
    if raw_image is None:
        raise gr.Error("Please upload an image.")

    rgb_image = raw_image.convert("RGB")
    # Preprocess and add a batch dimension of 1 before moving to the device.
    img = vis_processors["eval"](rgb_image).unsqueeze(0).to(device)
    txt = text_processors["eval"](caption)
    sample = {"image": img, "text_input": txt}

    # Inference only: disable autograd so neither forward pass keeps an
    # activation graph alive, which otherwise wastes memory on every request.
    with torch.no_grad():
        itm_logits = model(sample, match_head="itm")
        itc_output = model(sample, match_head="itc")

    # Column 1 of the softmax is read as the "match" probability.
    itm_score = torch.nn.functional.softmax(itm_logits, dim=1)[:, 1].item()
    # Expects a single-element tensor for one image/caption pair (the
    # previous float conversion had the same requirement) — TODO confirm.
    itc_score = itc_output.item()
    return f"{itm_score:.3f}", f"{itc_score:.4f}"
# Web UI: an image and a caption go in; predict() returns the two formatted
# scores, shown in two labelled text outputs.
app = gr.Interface(
    fn=predict,
    inputs=[
        grc.Image(type="pil"),  # handed to predict() as a PIL image
        grc.Textbox(),  # caption to score against the image
    ],
    outputs=[
        grc.Text(label="itm score"),
        grc.Text(label="itc score"),
    ],
)
app.launch()