"""Gradio demo: image captioning and visual question answering.

Powered by the unum-cloud/uform-gen2-dpo vision-language model.
"""

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# Hugging Face model id used for both the model and its processor.
MODEL_ID = "unum-cloud/uform-gen2-dpo"

# Token id of the model's "<|im_end|>" marker — generation stops when it
# is produced. NOTE(review): hard-coded in the original; presumably matches
# this checkpoint's tokenizer — confirm against processor.tokenizer.
EOS_TOKEN_ID = 32001

# Select the compute device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the model and processor.
# trust_remote_code=True is required because this model ships custom code.
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)


def _generate(image, prompt):
    """Run the model on (image, prompt) and return the decoded text.

    Shared generation path for captioning and VQA: preprocesses the
    inputs, casts the image tensor to the model's dtype, generates
    greedily, then decodes only the newly generated tokens and strips
    the "<|im_end|>" marker.

    Args:
        image: PIL image to condition the generation on.
        prompt: Text prompt fed to the processor.

    Returns:
        The cleaned, decoded generation as a string.
    """
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt",
    ).to(device)
    # The pixel tensor must match the model's dtype (bfloat16 here),
    # otherwise the custom forward pass raises a dtype mismatch.
    inputs["images"] = inputs["images"].to(model.dtype)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=EOS_TOKEN_ID,
            pad_token_id=processor.tokenizer.pad_token_id,
        )

    # Decode only the tokens generated after the prompt, not the echo.
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
    return decoded_text.replace("<|im_end|>", "").strip()


def transcribe_image(image):
    """Generate a caption for the given image.

    Args:
        image: PIL image from the Gradio input, or None if none uploaded.

    Returns:
        The generated caption, or a user-facing message when no image
        was provided.
    """
    if image is None:
        return "Please upload an image."
    return _generate(image, "a photo of")


def visual_question_answer(image, question):
    """Answer a free-form question about the given image.

    Args:
        image: PIL image from the Gradio input, or None if none uploaded.
        question: The user's question text.

    Returns:
        The model's answer, or a user-facing message when input is
        missing.
    """
    if image is None:
        return "Please upload an image."
    if not question:
        return "Please ask a question."
    # The model expects the prompt in its chat-template format.
    prompt = f"<|im_start|>question\n{question}<|im_end|><|im_start|>answer\n"
    return _generate(image, prompt)


# Create the Gradio interface.
with gr.Blocks() as demo:
    gr.Markdown("# Image Transcription and Visual Question Answering")
    gr.Markdown("Powered by the unum-cloud/uform-gen2-dpo model.")

    with gr.Tab("Image Transcription"):
        with gr.Row():
            transcribe_image_input = gr.Image(type="pil", label="Upload Image")
            transcribe_output = gr.Textbox(label="Generated Caption")
        transcribe_button = gr.Button("Generate Caption")

    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            vqa_image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column():
                vqa_question_input = gr.Textbox(label="Ask a question")
                vqa_output = gr.Textbox(label="Answer")
        vqa_button = gr.Button("Get Answer")

    # Connect the handlers to the Gradio components.
    transcribe_button.click(
        fn=transcribe_image,
        inputs=[transcribe_image_input],
        outputs=[transcribe_output],
    )
    vqa_button.click(
        fn=visual_question_answer,
        inputs=[vqa_image_input, vqa_question_input],
        outputs=[vqa_output],
    )

demo.launch()