# Hugging Face Space: gen3-visual / app.py (commit 854565e, author: sajofu)
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Set device: prefer the first CUDA GPU when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the model and processor
# We use trust_remote_code=True because this model has custom code.
# NOTE(review): bfloat16 halves memory vs float32; assumes the CPU/GPU build
# supports bfloat16 kernels — confirm on the target hardware.
model = AutoModel.from_pretrained(
"unum-cloud/uform-gen2-dpo",
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to(device)
# Processor bundles the image preprocessor and tokenizer for this checkpoint.
processor = AutoProcessor.from_pretrained(
"unum-cloud/uform-gen2-dpo",
trust_remote_code=True
)
def transcribe_image(image):
    """Generate a caption for an uploaded image.

    Args:
        image: PIL image from the Gradio widget, or None when nothing
            was uploaded.

    Returns:
        The generated caption string, or a user-facing hint when no
        image was provided.
    """
    if image is None:
        return "Please upload an image."

    # Captioning uses a plain completion-style prefix rather than the
    # chat template used for VQA.
    caption_prefix = "a photo of"
    model_inputs = processor(
        text=[caption_prefix],
        images=[image],
        return_tensors="pt",
    ).to(device)
    # Cast the pixel tensor to the model's dtype (bfloat16) so generation
    # does not mix float32 inputs with bfloat16 weights.
    model_inputs["images"] = model_inputs["images"].to(model.dtype)

    with torch.inference_mode():
        generated = model.generate(
            **model_inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id,
        )

    # Keep only the newly generated tokens; the prompt occupies the first
    # input_ids-many positions of the output sequence.
    prompt_token_count = model_inputs["input_ids"].shape[1]
    completion = processor.batch_decode(generated[:, prompt_token_count:])[0]
    # Strip the end-of-sequence marker the model appends.
    return completion.replace("<|im_end|>", "").strip()
def visual_question_answer(image, question):
    """Answer a free-form question about an uploaded image.

    Args:
        image: PIL image from the Gradio widget, or None when nothing
            was uploaded.
        question: The user's question text; may be empty.

    Returns:
        The model's answer string, or a user-facing hint when input
        is missing.
    """
    # Guard clauses for missing inputs.
    if image is None:
        return "Please upload an image."
    if not question:
        return "Please ask a question."

    # The checkpoint expects its chat template: a question turn followed
    # by the opening of an answer turn.
    chat_prompt = f"<|im_start|>question\n{question}<|im_end|><|im_start|>answer\n"
    model_inputs = processor(
        text=[chat_prompt],
        images=[image],
        return_tensors="pt",
    ).to(device)
    # Cast the pixel tensor to the model's dtype (bfloat16) so generation
    # does not mix float32 inputs with bfloat16 weights.
    model_inputs["images"] = model_inputs["images"].to(model.dtype)

    with torch.inference_mode():
        generated = model.generate(
            **model_inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id,
        )

    # Decode only the tokens produced after the prompt.
    prompt_token_count = model_inputs["input_ids"].shape[1]
    completion = processor.batch_decode(generated[:, prompt_token_count:])[0]
    # Strip the end-of-sequence marker the model appends.
    return completion.replace("<|im_end|>", "").strip()
# Build the Gradio UI: one tab per task, each wired to its handler.
with gr.Blocks() as demo:
    gr.Markdown("# Image Transcription and Visual Question Answering")
    gr.Markdown("Powered by the unum-cloud/uform-gen2-dpo model.")

    # Tab 1: free-form captioning of an uploaded image.
    with gr.Tab("Image Transcription"):
        with gr.Row():
            caption_image = gr.Image(type="pil", label="Upload Image")
            caption_text = gr.Textbox(label="Generated Caption")
        caption_btn = gr.Button("Generate Caption")

    # Tab 2: question answering about an uploaded image.
    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            vqa_image = gr.Image(type="pil", label="Upload Image")
            with gr.Column():
                vqa_question = gr.Textbox(label="Ask a question")
                vqa_answer = gr.Textbox(label="Answer")
        vqa_btn = gr.Button("Get Answer")

    # Wire buttons to their handlers.
    caption_btn.click(
        fn=transcribe_image,
        inputs=[caption_image],
        outputs=[caption_text],
    )
    vqa_btn.click(
        fn=visual_question_answer,
        inputs=[vqa_image, vqa_question],
        outputs=[vqa_answer],
    )

demo.launch()