# Hugging Face Space "AutoLabeler AI" — app.py
# (scraped page header: commit d907d93 "added ui" by avin-255)
import gradio as gr
import os
import uuid
from PIL import Image
import random
import csv
from zipfile import ZipFile
# BLIP-2 Libraries
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
### Load BLIP-2 model once at launch (shared by all sessions).
# Select the GPU when one is available; fall back to CPU otherwise.
# Fixes the original `"cpu" if torch.cuda.is_available() else "cpu"`,
# which could never pick CUDA and left the fp16/device_map branches dead.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"πŸ”₯ Using device: {device}")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    # On GPU, let accelerate place the weights; on CPU load normally.
    # NOTE(review): combining device_map="auto" with a trailing .to(device)
    # is redundant on GPU — kept to preserve the original call shape.
    device_map="auto" if device == "cuda" else None,
    # fp16 halves GPU memory; CPU kernels require fp32.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
### πŸ“‚ Session Utilities
def create_session_folder():
    """Create a uniquely named per-session scratch directory under /tmp.

    Returns:
        str: absolute path of the newly created directory.
    """
    folder = os.path.join("/tmp", str(uuid.uuid4()))
    os.makedirs(folder, exist_ok=True)
    return folder
def save_uploaded_images(images, session_path):
    """Persist uploaded images into the session folder as JPEG files.

    Accepts PIL images as before, and also path strings or tempfile-style
    objects exposing a ``.name`` path (what ``gr.File`` actually delivers),
    so the function works regardless of the Gradio component upstream.

    Args:
        images: iterable of PIL.Image.Image, path strings, or objects
            with a ``.name`` path attribute.
        session_path: directory to write ``img_<i>.jpg`` files into.

    Returns:
        list[str]: paths of the saved JPEGs, in input order.
    """
    saved_paths = []
    for i, image in enumerate(images):
        img_path = os.path.join(session_path, f"img_{i}.jpg")
        if not isinstance(image, Image.Image):
            # gr.File hands back tempfile wrappers (with .name) or paths.
            image = Image.open(getattr(image, "name", image))
        # JPEG has no alpha channel; normalize to RGB so RGBA/P images
        # don't raise OSError on save.
        image.convert("RGB").save(img_path)
        saved_paths.append(img_path)
    return saved_paths
### 🧠 BLIP-2 Captioning
def generate_caption_blip2(image_path):
    """Run BLIP-2 on one image file and return the generated caption string."""
    img = Image.open(image_path).convert("RGB")
    # Tokenize/featurize and move tensors to the module-level device.
    batch = processor(images=img, return_tensors="pt").to(device)
    token_ids = model.generate(**batch, max_new_tokens=30)
    return processor.tokenizer.decode(token_ids[0], skip_special_tokens=True)
### βš™οΈ Gradio Logic
def handle_upload(images):
    """Save uploaded files into a fresh session folder and build a preview.

    Args:
        images: list delivered by ``gr.File`` — elements are tempfile
            wrappers (with a ``.name`` path attribute) or plain path strings.

    Returns:
        tuple: (preview PIL images, session folder path, status message);
        the first two are ``None`` when nothing was uploaded.
    """
    if not images:
        return None, None, "❌ Please upload at least one image."
    session_path = create_session_folder()
    # gr.File yields file paths, not PIL images; open them first so the
    # save step receives real Image objects (the original passed the raw
    # file objects through, which have no .save method).
    pil_images = [Image.open(getattr(f, "name", f)).convert("RGB") for f in images]
    saved_image_paths = save_uploaded_images(pil_images, session_path)
    # Persist the file list so start_labeling can recover it later.
    with open(os.path.join(session_path, "images.txt"), "w") as f:
        for path in saved_image_paths:
            f.write(path + "\n")
    # Show up to five random thumbnails as an upload sanity check.
    preview_paths = random.sample(saved_image_paths, min(len(saved_image_paths), 5))
    preview_images = [Image.open(path) for path in preview_paths]
    return preview_images, session_path, f"βœ… Uploaded {len(saved_image_paths)} images."
def start_labeling(session_path):
    """Caption every image in the session, write labels.csv, and zip it all.

    Returns:
        str | None: path of the zip archive, or None when the session folder
        or its images.txt manifest does not exist.
    """
    if not os.path.exists(session_path):
        return None
    images_file = os.path.join(session_path, "images.txt")
    if not os.path.exists(images_file):
        return None
    with open(images_file, "r") as f:
        image_paths = [line.strip() for line in f]

    # Caption each image and record filename -> caption rows.
    csv_path = os.path.join(session_path, "labels.csv")
    with open(csv_path, mode="w", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["filename", "caption"])
        for img_path in image_paths:
            caption = generate_caption_blip2(img_path)
            writer.writerow([os.path.basename(img_path), caption])
            print(f"πŸ–ΌοΈ {os.path.basename(img_path)} β†’ {caption}")

    # Bundle the originals plus the CSV into one downloadable archive.
    zip_path = os.path.join(session_path, "labeled_output.zip")
    with ZipFile(zip_path, "w") as archive:
        for img_path in image_paths:
            archive.write(img_path, arcname=os.path.basename(img_path))
        archive.write(csv_path, arcname="labels.csv")
    return zip_path
### πŸš€ Gradio UI
# Gradio UI: upload -> preview -> caption -> download zip.
with gr.Blocks() as demo:
    gr.Markdown("## 🏷️ AutoLabeler AI")
    gr.Markdown("Upload up to 1000 images. We'll generate captions using BLIP-2 and return a zip file with `labels.csv`.")
    with gr.Row():
        # NOTE(review): gr.File returns file paths / tempfile objects, not
        # PIL images — handle_upload must account for that.
        image_input = gr.File(file_types=["image"], file_count="multiple", label="Upload Images")
        upload_button = gr.Button("Upload & Preview")
    image_gallery = gr.Gallery(label="Preview", columns=3, height="auto")
    session_text = gr.Textbox(label="Status", interactive=False)
    # Hidden textbox carries the per-session /tmp folder path between the
    # upload and labeling steps (poor man's session state).
    session_path_hidden = gr.Textbox(visible=False)
    upload_button.click(
        handle_upload,
        inputs=[image_input],
        outputs=[image_gallery, session_path_hidden, session_text]
    )
    start_button = gr.Button("Start Labeling")
    output_zip = gr.File(label="Download Labeled Zip")
    start_button.click(
        start_labeling,
        inputs=[session_path_hidden],
        outputs=[output_zip]
    )
demo.launch()