import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the model and processor
# We use trust_remote_code=True because this model has custom code.
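# Note: bfloat16 halves the memory footprint versus float32, but CPU support
# for it is uneven; if you run without a GPU, swapping in torch.float32 below
# is usually the safer choice.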
model = AutoModel.from_pretrained(
    "unum-cloud/uform-gen2-dpo",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to(device)

processor = AutoProcessor.from_pretrained(
    "unum-cloud/uform-gen2-dpo",
    trust_remote_code=True
)
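
# The processor bundles the tokenizer and the image transforms: called with
# text and images, it returns "input_ids" plus an "images" pixel tensor,
# which the functions below pass straight to model.generate().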

def transcribe_image(image):
    """
    Generates a caption for the given image.
    """
    if image is None:
        return "Please upload an image."

    prompt = "a photo of"
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)
    # Cast the pixel tensor to the model's dtype (bfloat16) so it matches the weights.
    inputs["images"] = inputs["images"].to(model.dtype)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
    # Strip the end-of-sequence marker from the decoded text.
    result = decoded_text.replace("<|im_end|>", "").strip()
    return result
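
# Visual question answering reuses the same generation settings as the
# captioning function above; only the prompt template changes.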
def visual_question_answer(image, question):
    """
    Answers a question about the given image.
    """
    if image is None:
        return "Please upload an image."
    if not question:
        return "Please ask a question."

    # The model expects a chat-style prompt with explicit question/answer markers.
    prompt = f"<|im_start|>question\n{question}<|im_end|><|im_start|>answer\n"
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)
    # Cast the pixel tensor to the model's dtype (bfloat16) so it matches the weights.
    inputs["images"] = inputs["images"].to(model.dtype)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
    # Strip the end-of-sequence marker from the decoded text.
    result = decoded_text.replace("<|im_end|>", "").strip()
    return result
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Transcription and Visual Question Answering")
    gr.Markdown("Powered by the unum-cloud/uform-gen2-dpo model.")

    with gr.Tab("Image Transcription"):
        with gr.Row():
            transcribe_image_input = gr.Image(type="pil", label="Upload Image")
            transcribe_output = gr.Textbox(label="Generated Caption")
        transcribe_button = gr.Button("Generate Caption")

    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            vqa_image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column():
                vqa_question_input = gr.Textbox(label="Ask a question")
                vqa_output = gr.Textbox(label="Answer")
        vqa_button = gr.Button("Get Answer")

    # Connect the functions to the Gradio components
    transcribe_button.click(
        fn=transcribe_image,
        inputs=[transcribe_image_input],
        outputs=[transcribe_output]
    )
    vqa_button.click(
        fn=visual_question_answer,
        inputs=[vqa_image_input, vqa_question_input],
        outputs=[vqa_output]
    )

demo.launch()
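# launch() serves the app locally (http://127.0.0.1:7860 by default);
# pass share=True to demo.launch() to expose a temporary public URL.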