Spaces: Build error
| import numpy as np | |
| import random | |
| from PIL import Image | |
| from ram.models import tag2text | |
| from ram import inference_tag2text as inference | |
| from ram import get_transform | |
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
# Device used for inference. The model is moved onto it explicitly below.
device = torch.device('cpu')

# Tag2Text model (vit='swin_b' backbone, 384px input) loaded from a local
# checkpoint file.
model = tag2text(pretrained='./tag2text_swin_14m.pth', image_size=384, vit='swin_b')
# NOTE(review): presumably the score cutoff for emitting a tag; confirm
# against ram's tag2text implementation.
model.threshold = 0.68
# Inference mode: the model is only used for prediction in this app.
model.eval()
model = model.to(device)
def tag_image(input_image):
    """Predict tags and a caption for one image using the Tag2Text model.

    Args:
        input_image: A ``PIL.Image.Image`` or a NumPy array as delivered by a
            Gradio ``Image`` component. ``None`` (the live interface fires
            again when the image is cleared) produces empty outputs instead
            of an exception.

    Returns:
        A ``(tags, caption)`` tuple of strings. ``tags`` is element 0 of the
        ``inference`` result with leading/trailing whitespace removed and
        internal whitespace runs collapsed to single spaces; ``caption`` is
        element 2 of that result.
    """
    if input_image is None:
        # Image.fromarray(None) would raise; return empty outputs instead.
        return '', ''

    transform = get_transform(image_size=384)
    if isinstance(input_image, Image.Image):
        img = input_image
    else:
        # Convert Gradio Image datatype (NumPy array) to PIL Image
        img = Image.fromarray(input_image)

    # Process the image
    print(f"Start processing, image size {img.size}")
    # Add a batch dimension and place the tensor on the same device as the
    # model's weights, so inference works whichever device the model is on.
    model_device = next(model.parameters()).device
    image = transform(img).unsqueeze(0).to(model_device)

    # Generate Tags and Captions
    res = inference(image, model)
    # The former `.replace(' ', ' ')` was a no-op. split()/join() strips
    # the ends and collapses any whitespace run to a single space.
    tags = ' '.join(res[0].split())
    caption = res[2]
    print(tags, caption)
    return tags, caption
# Interface for the demo.
# gr.inputs / gr.outputs and launch(enable_queue=...) no longer exist in
# current Gradio (4.x); use the top-level components and .queue() instead.
inputs = gr.Image(type="pil")
outputs = [gr.Textbox(label='Tags'), gr.Textbox(label='Caption')]

demo = gr.Interface(
    fn=tag_image,
    inputs=inputs,
    outputs=outputs,
    title="Tags and Captioning using Tag2Text",
    description="Upload an image and see its tags and captions in the corresponding output boxes",
    theme=gr.themes.Soft(),
    # Re-run tag_image whenever the input image changes.
    live=True,
)

# Launch the Gradio app: enable the request queue, then start the server.
# debug=True keeps the main thread blocked while the app runs.
demo.queue().launch(debug=True)