Spaces:
Runtime error
Runtime error
import csv
import io
import os
import os.path
import tempfile
from pathlib import Path

import gradio as gr
import PIL.Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration
BLIP_MODEL_ID = "Salesforce/blip2-opt-6.7b"
# Directory (under the current working directory) where caption CSVs are written.
CAPTION_CSV_DIR = os.path.join(os.getcwd(), 'csv')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type != 'cuda':
    print(f"You are using {device}. This is much slower than using "
          "a CUDA-enabled GPU.")

# Only use half precision on the GPU: float16 inference on CPU may fail with
# "not implemented for 'Half'" errors, so fall back to float32 there.
torch_dtype = torch.float16 if device.type == 'cuda' else torch.float32

blip_processor = Blip2Processor.from_pretrained(BLIP_MODEL_ID)
# Place the whole model on `device` in one step via `device_map`, rather than
# `device_map="auto"` followed by `.to(device)`: moving a model that
# accelerate has already dispatched is redundant and can fail when layers
# were offloaded.
blip_model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP_MODEL_ID,
    torch_dtype=torch_dtype,
    device_map={"": str(device)},
)
def captions_images_to_csv(images: list, prompt: str) -> list[str]:
    """Caption every uploaded image and save the captions as a CSV file.

    Args:
        images: The uploaded files from the Gradio ``gr.File`` component,
            forwarded unchanged to ``caption_images``.
        prompt: Text prompt given to BLIP-2 together with each image.

    Returns:
        ``[csv_path, csv_text]``: the path of the CSV file written inside
        ``CAPTION_CSV_DIR`` and the same CSV content as a string. The two
        items feed the two Gradio output components. Each row is
        ``file_stem,caption``.
    """
    caption_map = caption_images(images, prompt)

    # Use the csv module so captions containing commas, quotes or newlines
    # are quoted properly instead of breaking the row/column structure.
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator="\n").writerows(caption_map.items())
    csv_text = buffer.getvalue()

    # NamedTemporaryFile does not create a missing directory.
    os.makedirs(CAPTION_CSV_DIR, exist_ok=True)
    # delete=False keeps the file on disk so Gradio can serve it for download.
    with tempfile.NamedTemporaryFile(
        'w',
        dir=CAPTION_CSV_DIR,
        delete=False,
        suffix=".csv",
        encoding="utf-8",
        newline="",
    ) as f:
        f.write(csv_text)
    return [f.name, csv_text]
def caption_images(images: list, prompt: str) -> dict[str, str]:
    """Generate a BLIP-2 caption for each uploaded image file.

    Args:
        images: Uploaded files, each either an object with a ``.name``
            attribute holding the file path, or a plain path. ``None`` (no
            upload) is treated as an empty list.
        prompt: Text prompt passed to the processor alongside each image.

    Returns:
        Mapping from each file's stem (name without directory or extension)
        to its generated caption. Images that fail to process are reported
        on stdout and left out of the mapping.
    """
    results = {}
    for image in images or []:
        # Accept both file wrappers exposing `.name` and bare path strings.
        path = getattr(image, "name", image)
        try:
            # Open the file in its own `with` so the file handle is closed;
            # `open(...).convert(...)` as the context manager would only
            # close the converted in-memory copy.
            with PIL.Image.open(path) as img:
                rgb_image = img.convert('RGB')
            # Cast inputs to whatever dtype the model was loaded with.
            inputs = blip_processor(rgb_image, text=prompt, return_tensors='pt').to(device, blip_model.dtype)
            generated_ids = blip_model.generate(**inputs, max_new_tokens=20)
            results[Path(path).stem] = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        except Exception as e:
            # Best effort: skip the failing image but keep captioning the rest.
            print("An error occured while trying to process image: " + str(path))
            print(str(e))
    return results
# Gradio UI: upload images + prompt, get back a CSV of captions.
with gr.Blocks() as demo:
    with gr.Tab("Image Caption Bot"):
        images_input = gr.File(label="Input Images", file_count="multiple", file_types=["image"])
        caption_prompt = gr.Textbox(label="Prompt")
        caption_images_button = gr.Button("Submit")
        image_caption_output_file = gr.File(label="Output CSV")
        image_caption_output_label = gr.Textbox(label="Output Data", interactive=False)
        # captions_images_to_csv returns [csv_path, csv_text], matching the
        # two output components in order.
        caption_images_button.click(
            captions_images_to_csv,
            inputs=[images_input, caption_prompt],
            outputs=[image_caption_output_file, image_caption_output_label],
        )

print("Launching Gradio")
# Captioning several images with a multi-billion-parameter model can take a
# long time; route events through Gradio's queue so long-running requests are
# less likely to hit HTTP/proxy timeouts.
demo.queue()
demo.launch()