Spaces:
Sleeping
Sleeping
import csv
import os
import random
import tempfile
import uuid
from zipfile import ZipFile

import gradio as gr
import torch
from PIL import Image
# BLIP-2 Libraries
from transformers import Blip2Processor, Blip2ForConditionalGeneration
### Load BLIP-2 Model (runs once at launch)
# Pick the GPU when CUDA is available, otherwise run on the CPU.
# (Both branches of the original conditional were "cpu", so the CUDA-specific
# loading options below could never take effect.)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"π₯ Using device: {device}")

BLIP2_CHECKPOINT = "Salesforce/blip2-opt-2.7b"

processor = Blip2Processor.from_pretrained(BLIP2_CHECKPOINT)
if device == "cuda":
    # With device_map="auto" the weights are placed during loading, so the
    # model must not be moved again with .to() afterwards.
    model = Blip2ForConditionalGeneration.from_pretrained(
        BLIP2_CHECKPOINT,
        device_map="auto",
        torch_dtype=torch.float16,
    )
else:
    # CPU inference stays in full precision.
    model = Blip2ForConditionalGeneration.from_pretrained(
        BLIP2_CHECKPOINT,
        torch_dtype=torch.float32,
    ).to(device)
### Session Utilities
def create_session_folder():
    """Create a new, empty working directory for one upload session.

    The directory is created inside the platform temporary directory rather
    than a hard-coded ``/tmp``, and ``tempfile.mkdtemp`` guarantees the name
    is unique, so two sessions can never end up sharing a folder.

    Returns:
        str: Path of the newly created directory.
    """
    return tempfile.mkdtemp(prefix="session_")
def _write_jpeg(image, img_path):
    """Save *image* to *img_path*, converting modes JPEG cannot store to RGB."""
    # Objects without a ``mode`` attribute are saved as-is.
    if getattr(image, "mode", "RGB") not in ("RGB", "L", "CMYK"):
        image = image.convert("RGB")
    image.save(img_path)


def save_uploaded_images(images, session_path):
    """Store every uploaded image as ``img_<index>.jpg`` inside *session_path*.

    Each item of *images* may be:

    * an object exposing ``save(path)`` (e.g. an in-memory PIL image), which
      is written directly;
    * a filesystem path (``str`` / ``os.PathLike``); or
    * any other object with a ``name`` attribute holding a path (presumably
      the temp-file wrapper some Gradio versions pass for ``gr.File`` inputs
      -- confirm against the installed Gradio version).

    The last two kinds are opened with PIL before being written.

    Args:
        images: Iterable of uploaded items as described above.
        session_path: Existing directory that receives the JPEG files.

    Returns:
        list[str]: Paths of the written files, in upload order.
    """
    saved_paths = []
    for index, item in enumerate(images):
        img_path = os.path.join(session_path, f"img_{index}.jpg")
        if hasattr(item, "save"):
            _write_jpeg(item, img_path)
        else:
            source = item if isinstance(item, (str, os.PathLike)) else item.name
            # Close the uploaded file as soon as its copy has been written.
            with Image.open(source) as opened:
                _write_jpeg(opened, img_path)
        saved_paths.append(img_path)
    return saved_paths
### BLIP-2 Captioning
def generate_caption_blip2(image_path):
    """Generate a short caption for the image stored at *image_path*.

    Uses the module-level ``processor`` and ``model`` and generates at most
    30 new tokens.

    Args:
        image_path: Path of an image file readable by PIL.

    Returns:
        str: The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    # convert() returns a fully loaded copy, so the file can be closed at once.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Match the model's device and floating-point precision (matters when the
    # model was loaded in half precision).
    inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
    output = model.generate(**inputs, max_new_tokens=30)
    caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    # Strip stray leading/trailing whitespace so it does not end up in the CSV.
    return caption.strip()
### Gradio Logic
def handle_upload(images):
    """Persist an upload into a new session folder and build a preview.

    The saved image paths are also written, one per line, to ``images.txt``
    inside the session folder so that ``start_labeling`` can find them.

    Args:
        images: Uploaded items from the file input; falsy when nothing was
            uploaded.

    Returns:
        tuple: ``(preview_images, session_path, status_message)``. The first
        two elements are ``None`` when *images* is empty. Otherwise
        ``preview_images`` is a random sample of at most five saved images,
        loaded as PIL images.
    """
    if not images:
        return None, None, "β Please upload at least one image."

    session_path = create_session_folder()
    saved_image_paths = save_uploaded_images(images, session_path)

    # Save paths to session file
    with open(os.path.join(session_path, "images.txt"), "w") as f:
        f.writelines(f"{path}\n" for path in saved_image_paths)

    preview_paths = random.sample(saved_image_paths, min(len(saved_image_paths), 5))
    preview_images = []
    for path in preview_paths:
        # copy() loads the pixel data, so no file handle outlives this block.
        with Image.open(path) as opened:
            preview_images.append(opened.copy())

    return preview_images, session_path, f"β Uploaded {len(saved_image_paths)} images."
def start_labeling(session_path):
    """Caption every image of a session and bundle the results into a zip.

    Reads the image paths listed in ``<session_path>/images.txt``, writes
    ``labels.csv`` (columns ``filename``, ``caption``) next to it, and packs
    the images together with the CSV into ``labeled_output.zip``.

    Args:
        session_path: Session directory produced by the upload step. May be
            ``None`` or empty if labeling is started before any upload.

    Returns:
        str | None: Path of the created zip file, or ``None`` when the
        session directory or its ``images.txt`` does not exist.
    """
    # Guard against a missing/empty path first: os.path functions raise
    # TypeError on None.
    if not session_path or not os.path.isdir(session_path):
        return None

    images_file = os.path.join(session_path, "images.txt")
    if not os.path.isfile(images_file):
        return None

    with open(images_file, "r") as f:
        # Ignore blank lines so they are not treated as image paths.
        image_paths = [line.strip() for line in f if line.strip()]

    # Labeling step
    csv_path = os.path.join(session_path, "labels.csv")
    with open(csv_path, mode="w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["filename", "caption"])
        for img_path in image_paths:
            caption = generate_caption_blip2(img_path)
            writer.writerow([os.path.basename(img_path), caption])
            print(f"πΌοΈ {os.path.basename(img_path)} β {caption}")

    # Zip everything
    zip_path = os.path.join(session_path, "labeled_output.zip")
    with ZipFile(zip_path, "w") as zipf:
        for img_path in image_paths:
            zipf.write(img_path, arcname=os.path.basename(img_path))
        zipf.write(csv_path, arcname="labels.csv")

    return zip_path
### Gradio UI
# Page layout and event wiring. Components render in the order they are
# created inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("## π·οΈ AutoLabeler AI")
    gr.Markdown(
        "Upload up to 1000 images. We'll generate captions using BLIP-2 "
        "and return a zip file with `labels.csv`."
    )

    # Upload controls.
    with gr.Row():
        image_input = gr.File(
            file_types=["image"],
            file_count="multiple",
            label="Upload Images",
        )
        upload_button = gr.Button("Upload & Preview")

    image_gallery = gr.Gallery(label="Preview", columns=3, height="auto")
    session_text = gr.Textbox(label="Status", interactive=False)
    # Invisible field that carries the session folder path from the upload
    # handler to the labeling handler.
    session_path_hidden = gr.Textbox(visible=False)

    upload_button.click(
        fn=handle_upload,
        inputs=[image_input],
        outputs=[image_gallery, session_path_hidden, session_text],
    )

    # Labeling controls.
    start_button = gr.Button("Start Labeling")
    output_zip = gr.File(label="Download Labeled Zip")

    start_button.click(
        fn=start_labeling,
        inputs=[session_path_hidden],
        outputs=[output_zip],
    )

demo.launch()