# Hugging Face Spaces snapshot — page status banner read "Build error".
import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The GIT base processor handles image preprocessing and token decoding;
# the fine-tuned checkpoint generates portrait-style captions.
processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption").to(device)
def generate_captions(images, max_length=200):
    """Generate one caption per image in a batch.

    Args:
        images: a list of PIL images (anything the processor accepts).
        max_length: maximum token length of each generated caption.

    Returns:
        list[str]: decoded captions, one per input image, in input order.
    """
    # Preprocess the images into model-ready pixel tensors on the target device.
    inputs = processor(images=images, return_tensors="pt").to(device)
    pixel_values = inputs.pixel_values
    # Inference only: disable autograd so generation does not track gradients,
    # which wastes memory and time.
    with torch.no_grad():
        generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_caption
def generate_caption(image, max_length=200):
    """Caption a single image by delegating to the batch captioner."""
    captions = generate_captions([image], max_length)
    return captions[0]
# Wire the captioner into a simple image-in / text-out Gradio interface.
demo = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(type="pil", label="Upload Image", height=400),
        gr.Slider(minimum=10, maximum=400, value=200, step=8, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
    theme="gradio/monochrome",
    title="Stable Diffusion Portrait Captioner",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()