# Hugging Face Spaces app (Space status when captured: Sleeping).
| import torch | |
| import gradio as gr | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import InterpolationMode | |
| from models.blip import blip_decoder | |
# Per-channel statistics used to standardize input images (one value per
# colour channel; images are converted to RGB before reaching normalize()).
MEAN = torch.tensor((0.48145466, 0.4578275, 0.40821073))
STD = torch.tensor((0.26862954, 0.26130258, 0.27577711))


# Normalize function
def normalize(img_tensor):
    """Standardize an image tensor channel-wise with ``MEAN`` and ``STD``.

    Args:
        img_tensor: Image tensor with channels on axis 1 when batched,
            i.e. ``(B, C, H, W)``; an unbatched ``(C, H, W)`` tensor is
            also accepted.

    Returns:
        ``(img_tensor - MEAN) / STD`` as a batched tensor. A batched input
        keeps its batch size; an unbatched input gains a leading batch
        axis of size 1.
    """
    if img_tensor.dim() == 3:
        # Unbatched input: add the batch axis so the result is always batched.
        img_tensor = img_tensor.unsqueeze(0)
    # Shape (1, C, 1, 1) broadcasts over every batch element and pixel, so
    # batches larger than one are handled without any squeeze/unsqueeze.
    mean = MEAN.to(img_tensor.device).view(1, -1, 1, 1)
    std = STD.to(img_tensor.device).view(1, -1, 1, 1)
    return (img_tensor - mean) / std
# Preprocess function
def preprocess_img(raw_img, img_size):
    """Turn an input image into a square, batched tensor.

    Args:
        raw_img: Image object exposing ``convert("RGB")``
            (presumably a PIL image, as supplied by the Gradio input).
        img_size: Target height and width in pixels.

    Returns:
        The resized image converted by ``transforms.ToTensor`` with a
        leading batch axis of size 1 added.
    """
    resize = transforms.Resize(
        (img_size, img_size),
        interpolation=InterpolationMode.BICUBIC,
    )
    to_tensor = transforms.ToTensor()

    rgb_img = raw_img.convert("RGB")
    return to_tensor(resize(rgb_img)).unsqueeze(0)
# Hyperparameters
IMG_SIZE = 384  # side length (pixels) images are resized to before captioning
MODEL_URL = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth"
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load model: build the captioning decoder from the pretrained checkpoint,
# move it to the selected device and switch it to evaluation mode.
model = blip_decoder(vit="base", image_size=IMG_SIZE, pretrained=MODEL_URL).to(DEVICE)
model.eval()
# Function to generate caption
def generate_caption(user_image):
    """Generate a caption for an image submitted through the Gradio UI.

    Args:
        user_image: The uploaded image (a PIL image, given the interface's
            ``gr.Image(type="pil")`` input), or ``None`` when the user
            submitted the form without providing an image.

    Returns:
        The first caption returned by ``model.generate`` for the image.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an internal ``AttributeError``.
    """
    if user_image is None:
        # Gradio passes None for an empty image input; fail with a clear,
        # user-facing message rather than crashing inside preprocessing.
        raise gr.Error("Please upload an image before requesting a caption.")

    img = preprocess_img(user_image, IMG_SIZE).to(DEVICE)
    img_norm = normalize(img)
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        captions = model.generate(
            img_norm,
            sample=False,
            num_beams=3,
            max_length=20,
            min_length=5,
        )
    # A single image was submitted, so take the first returned caption.
    return captions[0]
# Gradio interface: one image input, one text output holding the caption.
_image_input = gr.Image(type="pil")
_caption_output = gr.Textbox(label="Generated Caption")

demo = gr.Interface(
    fn=generate_caption,
    inputs=_image_input,
    outputs=_caption_output,
    title="BLIP Image Captioning (Base Model)",
    description="This model is implemented according to the official BLIP repository: https://github.com/salesforce/BLIP",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()