import gradio as gr
import torch
from transformers import AutoModel, AutoProcessor

# Select the device, and a dtype it supports: bfloat16 on GPU, float32 on CPU
# (bfloat16 inference on CPU tends to be slow or only partially supported).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load the model and processor. trust_remote_code=True is required because
# this repository ships custom modelling code.
model = AutoModel.from_pretrained(
    "unum-cloud/uform-gen2-dpo",
    trust_remote_code=True,
    torch_dtype=dtype
).to(device)
processor = AutoProcessor.from_pretrained(
    "unum-cloud/uform-gen2-dpo",
    trust_remote_code=True
)
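
# Optional sanity check (commented out; purely illustrative): confirm where
# the weights landed and which dtype they use.
# print(next(model.parameters()).device, model.dtype)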

def _generate(prompt, image):
    """
    Runs the model on a (prompt, image) pair and returns the decoded text
    with the end-of-sequence token stripped. Shared by both tasks below.
    """
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)

    # The processor stores the pixel data under the 'images' key; cast it to
    # the model's dtype so it matches the weights (e.g. bfloat16 on GPU).
    inputs["images"] = inputs["images"].to(model.dtype)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, skipping the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]

    # Remove the end-of-sequence token.
    return decoded_text.replace("<|im_end|>", "").strip()

def transcribe_image(image):
    """
    Generates a caption for the given image.
    """
    if image is None:
        return "Please upload an image."
    return _generate("a photo of", image)

def visual_question_answer(image, question):
    """
    Answers a question about the given image.
    """
    if image is None:
        return "Please upload an image."
    if not question:
        return "Please ask a question."

    # The model expects the question wrapped in its chat-style prompt format.
    prompt = f"<|im_start|>question\n{question}<|im_end|><|im_start|>answer\n"
    return _generate(prompt, image)
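
# Both functions can also be exercised without the UI. A minimal sketch,
# assuming a local file "example.jpg" (hypothetical, not part of this app):
#
#   from PIL import Image
#   img = Image.open("example.jpg")
#   print(transcribe_image(img))
#   print(visual_question_answer(img, "What objects are visible?"))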

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Transcription and Visual Question Answering")
    gr.Markdown("Powered by the unum-cloud/uform-gen2-dpo model.")

    with gr.Tab("Image Transcription"):
        with gr.Row():
            transcribe_image_input = gr.Image(type="pil", label="Upload Image")
            transcribe_output = gr.Textbox(label="Generated Caption")
        transcribe_button = gr.Button("Generate Caption")

    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            vqa_image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column():
                vqa_question_input = gr.Textbox(label="Ask a question")
                vqa_output = gr.Textbox(label="Answer")
        vqa_button = gr.Button("Get Answer")

    # Connect the functions to the Gradio components
    transcribe_button.click(
        fn=transcribe_image,
        inputs=[transcribe_image_input],
        outputs=[transcribe_output]
    )
    vqa_button.click(
        fn=visual_question_answer,
        inputs=[vqa_image_input, vqa_question_input],
        outputs=[vqa_output]
    )

demo.launch()
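
# Optional variants (not part of the original app): in notebooks or behind a
# firewall, Gradio can create a temporary public URL, and request queuing
# helps when generation takes several seconds per call:
#   demo.queue().launch(share=True)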