Spaces:
Sleeping
Sleeping
File size: 4,193 Bytes
d907d93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import gradio as gr
import os
import uuid
from PIL import Image
import random
import csv
from zipfile import ZipFile
# BLIP-2 Libraries
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
### Load BLIP-2 model (runs once at launch)
# BUG FIX: the original read `"cpu" if torch.cuda.is_available() else "cpu"`,
# which picked CPU in *both* branches and silently disabled CUDA even on GPU
# machines, defeating the fp16/device_map handling below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"π₯ Using device: {device}")

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# On CUDA: shard with accelerate (device_map="auto") and use half precision
# to fit the 2.7B model; on CPU: keep float32 (fp16 CPU inference is slow and
# poorly supported).
# NOTE(review): combining device_map="auto" with a trailing .to(device) is
# usually redundant — accelerate already places the weights; verify against
# the installed transformers/accelerate versions.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    device_map="auto" if device == "cuda" else None,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
### Session utilities
def create_session_folder():
    """Create a unique per-session working directory under /tmp and return its path."""
    folder = os.path.join("/tmp", str(uuid.uuid4()))
    os.makedirs(folder, exist_ok=True)
    return folder
def save_uploaded_images(images, session_path):
    """Persist uploaded images into *session_path* as img_0.jpg, img_1.jpg, ...

    Accepts PIL images, filesystem path strings, or file-like objects exposing
    a ``.name`` attribute (what ``gr.File`` actually yields — the original
    assumed PIL images and would crash on gr.File uploads).

    Returns the list of saved file paths, in upload order.
    """
    saved_paths = []
    for i, image in enumerate(images):
        img_path = os.path.join(session_path, f"img_{i}.jpg")
        if not isinstance(image, Image.Image):
            # gr.File hands back tempfile wrappers / path strings, not PIL images
            src = image if isinstance(image, str) else image.name
            image = Image.open(src)
        # JPEG cannot store alpha or palette modes; normalize to RGB first
        # (the original raised OSError on RGBA/P inputs).
        image.convert("RGB").save(img_path)
        saved_paths.append(img_path)
    return saved_paths
### BLIP-2 captioning
def generate_caption_blip2(image_path):
    """Caption a single image file with the globally loaded BLIP-2 model.

    Reads the file at *image_path*, runs greedy generation (max 30 new
    tokens) and returns the decoded caption string.
    """
    image = Image.open(image_path).convert("RGB")
    # Cast pixel values to the model's dtype: the model is loaded in float16
    # on CUDA (see module setup), and feeding it float32 tensors there raises
    # a dtype-mismatch error. On CPU this is a no-op float32 cast.
    dtype = torch.float16 if device == "cuda" else torch.float32
    inputs = processor(images=image, return_tensors="pt").to(device, dtype)
    output = model.generate(**inputs, max_new_tokens=30)
    caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)
    return caption
### Gradio logic
def handle_upload(images):
    """Save the uploaded *images* into a fresh session folder and build a preview.

    Returns a 3-tuple ``(preview_images, session_path, status_message)``;
    the first two are ``None`` when nothing was uploaded.
    """
    if not images:
        return None, None, "β Please upload at least one image."
    session_path = create_session_folder()
    saved_image_paths = save_uploaded_images(images, session_path)
    # Persist the saved paths so start_labeling can recover them later
    with open(os.path.join(session_path, "images.txt"), "w") as f:
        for path in saved_image_paths:
            f.write(path + "\n")
    # Preview up to 5 randomly chosen images
    preview_paths = random.sample(saved_image_paths, min(len(saved_image_paths), 5))
    preview_images = [Image.open(path) for path in preview_paths]
    # BUG FIX: the original split this f-string across two physical lines,
    # which is an unterminated string literal (SyntaxError); rejoined here.
    return preview_images, session_path, f"β Uploaded {len(saved_image_paths)} images."
def start_labeling(session_path):
    """Caption every saved image, write labels.csv, and zip it all up.

    Returns the path of the produced zip archive, or ``None`` when the
    session folder or its image manifest is missing.
    """
    if not os.path.exists(session_path):
        return None
    manifest = os.path.join(session_path, "images.txt")
    if not os.path.exists(manifest):
        return None

    with open(manifest, "r") as fh:
        image_paths = [line.strip() for line in fh.readlines()]

    # Caption each image and record the result in labels.csv
    csv_path = os.path.join(session_path, "labels.csv")
    with open(csv_path, mode="w", newline="") as out:
        writer = csv.writer(out)
        writer.writerow(["filename", "caption"])
        for img in image_paths:
            caption = generate_caption_blip2(img)
            writer.writerow([os.path.basename(img), caption])
            print(f"πΌοΈ {os.path.basename(img)} β {caption}")

    # Bundle the images and the CSV into a single downloadable archive
    zip_path = os.path.join(session_path, "labeled_output.zip")
    with ZipFile(zip_path, "w") as archive:
        for img in image_paths:
            archive.write(img, arcname=os.path.basename(img))
        archive.write(csv_path, arcname="labels.csv")
    return zip_path
### Gradio UI
# Declarative layout + event wiring; the heavy lifting lives in
# handle_upload / start_labeling above.
with gr.Blocks() as demo:
    gr.Markdown("## π·οΈ AutoLabeler AI")
    gr.Markdown("Upload up to 1000 images. We'll generate captions using BLIP-2 and return a zip file with `labels.csv`.")
    with gr.Row():
        # gr.File (not gr.Image) so many files can be uploaded at once
        image_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Images")
        upload_button = gr.Button("Upload & Preview")
    image_gallery = gr.Gallery(label="Preview", columns=3, height="auto")
    session_text = gr.Textbox(label="Status", interactive=False)
    # Hidden textbox carries the session folder path between the two
    # button callbacks (upload -> labeling)
    session_path_hidden = gr.Textbox(visible=False)
    upload_button.click(
        handle_upload,
        inputs=[image_input],
        outputs=[image_gallery, session_path_hidden, session_text]
    )
    start_button = gr.Button("Start Labeling")
    output_zip = gr.File(label="Download Labeled Zip")
    start_button.click(
        start_labeling,
        inputs=[session_path_hidden],
        outputs=[output_zip]
    )
demo.launch()
|