import gc
from functools import partial
import gradio as gr
from PIL.Image import Image
from transformers import AutoProcessor, BlipForConditionalGeneration
from utils import get_pytorch_device, spaces_gpu, request_image
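# The helpers imported from `utils` are defined elsewhere in this repo; the
# descriptions below are inferred from how they are used in this module:
#   - get_pytorch_device() picks the best available device (CUDA/XPU/MPS/CPU).
#   - spaces_gpu marks a function for GPU allocation on Hugging Face Spaces.
#   - request_image(url) fetches an image from a URL into a PIL Image.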


@spaces_gpu
def image_to_text(model: str, image: Image) -> list[str]:
    """Generate text captions for an image using BLIP model.

    This function uses a BLIP (Bootstrapping Language-Image Pre-training) model
    to generate multiple caption candidates for the input image. The model is
    loaded, inference is performed, and then cleaned up to free GPU memory.

    Args:
        model: Hugging Face model ID to use for image captioning.
        image: PIL Image object to generate captions for.

    Returns:
        List of string captions describing the image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Deletes the model and input tensors and runs garbage collection
          after inference so their memory can be reused.
        - Uses beam search with 3 beams (max length 20, min length 5) and
          returns all 3 beam candidates.
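
    Example:
        A minimal sketch; the model ID is an assumption (any BLIP captioning
        checkpoint on the Hugging Face Hub should work):

            from PIL import Image

            captions = image_to_text(
                "Salesforce/blip-image-captioning-base",
                Image.open("photo.jpg"),
            )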
    """
    pytorch_device = get_pytorch_device()
    processor = AutoProcessor.from_pretrained(model)
    model_instance = BlipForConditionalGeneration.from_pretrained(
        model,
        use_safetensors=True,  # Use safetensors to avoid the torch.load restriction.
    ).to(pytorch_device)
    inputs = processor(images=image, return_tensors="pt").to(pytorch_device)
    generated_ids = model_instance.generate(
        pixel_values=inputs.pixel_values,
        num_beams=3,
        num_return_sequences=3,  # Return every beam candidate, not just the best.
        max_length=20,
        min_length=5,
    )
    results = processor.batch_decode(generated_ids, skip_special_tokens=True)
    # Drop the model and tensors and force a collection so their memory can be
    # reclaimed before the next request.
    del model_instance, inputs
    gc.collect()
    return results


def create_image_to_text_tab(model: str):
    """Create the image-to-text captioning tab in the Gradio interface.
    
    This function sets up all UI components for image captioning, including:
    - URL input textbox for fetching images from the web
    - Button to retrieve image from URL
    - Image preview component
    - Caption button and output list
    
    Args:
        model: Hugging Face model ID to use for image captioning.
    """
    gr.Markdown("Generate a text description of an image.")
    image_to_text_url_input = gr.Textbox(label="Image URL")
    image_to_text_image_request_button = gr.Button("Get Image")
    image_to_text_image_input = gr.Image(label="Image", type="pil")
    image_to_text_image_request_button.click(
        fn=request_image,
        inputs=image_to_text_url_input,
        outputs=image_to_text_image_input
    )
    image_to_text_button = gr.Button("Caption")
    image_to_text_output = gr.List(label="Captions", headers=["Caption"])
    image_to_text_button.click(
        fn=partial(image_to_text, model),
        inputs=image_to_text_image_input,
        outputs=image_to_text_output
    )
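

# Usage sketch (illustrative; the surrounding Blocks app and the model ID are
# assumptions, not defined in this module):
#
#     with gr.Blocks() as demo:
#         with gr.Tab("Image to Text"):
#             create_image_to_text_tab("Salesforce/blip-image-captioning-base")
#     demo.launch()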