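# Gradio demo for image captioning and visual question answering
# with the unum-cloud/uform-gen2-dpo vision-language model.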
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor
# Set device: first CUDA GPU if available, otherwise CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the model and processor.
# trust_remote_code=True is required because this model ships custom code.
model = AutoModel.from_pretrained(
    "unum-cloud/uform-gen2-dpo",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to(device)
processor = AutoProcessor.from_pretrained(
    "unum-cloud/uform-gen2-dpo",
    trust_remote_code=True
)
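# Note: bfloat16 halves memory use relative to float32, but bfloat16 kernels
# on a CPU-only host can be slow or unsupported; if no GPU is available,
# loading with torch_dtype=torch.float32 is a reasonable fallback.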


def transcribe_image(image):
    """
    Generates a caption for the given image.
    """
    if image is None:
        return "Please upload an image."

    # Captioning uses a plain prefix prompt that the model completes
    prompt = "a photo of"
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)
    # The processor returns the image tensor under the 'images' key in
    # float32; cast it to the model's dtype (bfloat16) to avoid a mismatch
    inputs["images"] = inputs["images"].to(model.dtype)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, skipping the prompt
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
    # Remove the end-of-sequence token
    result = decoded_text.replace("<|im_end|>", "").strip()
    return result
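# Note: an alternative to the manual replace() above may be
# processor.batch_decode(..., skip_special_tokens=True), assuming <|im_end|>
# is registered as a special token in this tokenizer.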


def visual_question_answer(image, question):
    """
    Answers a question about the given image.
    """
    if image is None:
        return "Please upload an image."
    if not question:
        return "Please ask a question."

    # The model expects chat-style prompts delimited by <|im_start|>/<|im_end|>
    prompt = f"<|im_start|>question\n{question}<|im_end|><|im_start|>answer\n"
    inputs = processor(
        text=[prompt],
        images=[image],
        return_tensors="pt"
    ).to(device)
    # The processor returns the image tensor under the 'images' key in
    # float32; cast it to the model's dtype (bfloat16) to avoid a mismatch
    inputs["images"] = inputs["images"].to(model.dtype)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            do_sample=False,
            use_cache=True,
            max_new_tokens=128,
            eos_token_id=32001,
            pad_token_id=processor.tokenizer.pad_token_id
        )

    # Decode only the newly generated tokens, skipping the prompt
    prompt_len = inputs["input_ids"].shape[1]
    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
    # Remove the end-of-sequence token
    result = decoded_text.replace("<|im_end|>", "").strip()
    return result
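# Quick smoke test without the UI (hypothetical image path, for illustration):
#   img = Image.open("example.jpg")
#   print(transcribe_image(img))
#   print(visual_question_answer(img, "What is in the picture?"))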


# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Transcription and Visual Question Answering")
    gr.Markdown("Powered by the unum-cloud/uform-gen2-dpo model.")

    with gr.Tab("Image Transcription"):
        with gr.Row():
            transcribe_image_input = gr.Image(type="pil", label="Upload Image")
            transcribe_output = gr.Textbox(label="Generated Caption")
        transcribe_button = gr.Button("Generate Caption")

    with gr.Tab("Visual Question Answering"):
        with gr.Row():
            vqa_image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column():
                vqa_question_input = gr.Textbox(label="Ask a question")
                vqa_output = gr.Textbox(label="Answer")
        vqa_button = gr.Button("Get Answer")

    # Connect the functions to the Gradio components
    transcribe_button.click(
        fn=transcribe_image,
        inputs=[transcribe_image_input],
        outputs=[transcribe_output]
    )
    vqa_button.click(
        fn=visual_question_answer,
        inputs=[vqa_image_input, vqa_question_input],
        outputs=[vqa_output]
    )

demo.launch()
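# When running locally, demo.launch(share=True) can be used instead to get a
# temporary public link; on Hugging Face Spaces the default launch() is enough.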