Spaces:
Runtime error
Runtime error
File size: 6,480 Bytes
c6a4a94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import gradio as gr
import os
from pathlib import Path
from PIL import Image
import torch
import clip
import yaml
import pandas as pd
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from pprint import pprint as print
categories = {}

# Configuration loading and validation.
# Every key path below is read somewhere later in this script; checking them
# up front turns a confusing KeyError during model setup into a clear message.
_REQUIRED_CONFIG_PATHS = (
    ("categories",),
    ("config", "models", "blip", "model_name"),
    ("config", "models", "clip", "model_name"),
    ("config", "paths", "images"),
    ("config", "paths", "output"),
)


def _validate_config(config):
    """Raise ValueError if *config* lacks any section this script reads.

    Args:
        config: the object parsed from the YAML file (``None`` for an empty file).

    Raises:
        ValueError: if the document is not a mapping, a required key path is
            missing, or ``categories`` is not a mapping.
    """
    if not isinstance(config, dict):
        raise ValueError("Config file is empty or not a mapping")
    for key_path in _REQUIRED_CONFIG_PATHS:
        node = config
        for depth, key in enumerate(key_path):
            if not isinstance(node, dict) or key not in node:
                missing = ".".join(key_path[: depth + 1])
                raise ValueError(f"Missing necessary config section: {missing}")
            node = node[key]
    # The category table is iterated with .items() later on.
    if not isinstance(config["categories"], dict):
        raise ValueError("Config section 'categories' must be a mapping of category name to details")


def load_config(path):
    """Load and validate the YAML configuration file at *path*.

    Args:
        path: path to the YAML config file.

    Returns:
        The parsed configuration as a dict.

    Raises:
        FileNotFoundError: if *path* does not exist (a message is printed first).
        ValueError: if required sections are missing (a message is printed first).
    """
    try:
        with open(path, encoding="utf-8") as file:
            # safe_load: a plain config never needs arbitrary Python object tags.
            config = yaml.safe_load(file)
        _validate_config(config)
        return config
    except FileNotFoundError:
        print(f"Error: config file not found: {path}")
        raise
    except ValueError as e:
        print(str(e))
        raise
config = load_config('config.yml')
categories = config['categories']
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Use half precision only on GPU: float16 inference on CPU is poorly supported
# by PyTorch (some ops raise "not implemented for 'Half'"), so fall back to float32.
blip_dtype = torch.float16 if device == "cuda" else torch.float32

# Initialize models and processor
blip_model_name = config['config']['models']['blip']['model_name']
processor = AutoProcessor.from_pretrained(blip_model_name)
blip_model = Blip2ForConditionalGeneration.from_pretrained(blip_model_name, torch_dtype=blip_dtype)
blip_model.to(device)
model, preprocess = clip.load(config['config']['models']['clip']['model_name'], device=device)

# Index (into images_df) of the image currently shown in the UI.
current_index = 0

# Precompute CLIP text embeddings for each category description, one row per
# description (a string or list of strings), L2-normalised so that dot
# products against normalised image features are cosine similarities.
# no_grad: inference only, so do not build autograd graphs.
with torch.no_grad():
    for category_name, category_details in categories.items():
        print(f"Precomputing embeddings for category: {category_name}; {category_details}")
        # truncate=True: over-long descriptions are clipped to CLIP's context
        # length instead of raising.
        tokens = clip.tokenize(category_details['description'], truncate=True).to(device)
        embeddings_tensor = model.encode_text(tokens)
        embeddings_tensor = embeddings_tensor / embeddings_tensor.norm(dim=-1, keepdim=True)
        category_details['embeddings'] = embeddings_tensor.cpu().numpy()
def load_image(path):
    """Open an image file and prepare it for CLIP.

    Args:
        path: path to the image file.

    Returns:
        ``(image, image_input)``: an RGB PIL image detached from the file (also
        fed to BLIP later), and the CLIP-preprocessed tensor with a leading
        batch dimension on ``device``. Returns ``(None, None)`` and prints the
        error if the file cannot be opened or preprocessed (deliberate
        best-effort so one bad file does not stop the scan).
    """
    try:
        # The context manager releases the file handle immediately; the file
        # may later be renamed or moved by move_images_to_category_folder.
        # convert() returns an in-memory copy, so the image survives the close.
        with Image.open(path) as opened:
            image = opened.convert("RGB")
        image_input = preprocess(image).unsqueeze(0).to(device)
        return image, image_input
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return None, None
def predict_category(image_input, caption_input=None):
    """Pick the category whose description embeddings best match an image.

    Args:
        image_input: CLIP-preprocessed image batch as returned by ``load_image``,
            or ``None`` when loading failed.
        caption_input: optional text (the BLIP caption in this script). When
            given, its CLIP text embedding is scored alongside the image.

    Returns:
        ``(best_category, features)``. ``features`` is a NumPy array of the
        L2-normalised query embeddings: one row for the image, plus one for the
        caption if supplied. Returns ``(None, None)`` when ``image_input`` is
        None. ``best_category`` is None when, e.g., ``categories`` is empty.
    """
    if image_input is None:
        return None, None
    with torch.no_grad():
        query = model.encode_image(image_input)
        if caption_input is not None:
            tokens = clip.tokenize(caption_input, truncate=True).to(device)
            query = torch.cat([query, model.encode_text(tokens)])
        query = query / query.norm(dim=-1, keepdim=True)
    features = query.cpu().numpy()

    # Score in float32 on the CPU so fp16 GPU embeddings keep precision.
    query32 = query.float().cpu()
    best_category = None
    best_similarity = float("-inf")  # any real score beats this; -1 could be unreachable
    for category_name, category_details in categories.items():
        # Stored embeddings have one row per description. Re-normalise so both
        # sides are unit vectors (cosine similarity).
        text = torch.as_tensor(category_details['embeddings']).float()
        text = text / text.norm(dim=-1, keepdim=True)
        # Mean over every (query row, description row) pair. Unlike elementwise
        # broadcasting, this works for any number of descriptions and does not
        # favour categories with more of them.
        similarity = (query32 @ text.T).mean().item()
        if similarity > best_similarity:
            best_similarity = similarity
            best_category = category_name
    return best_category, features
IMAGE_SUFFIXES = {'.png', '.jpg', '.jpeg'}


def _caption_to_filename(caption, image_path):
    """Turn a caption into a filesystem-safe file name keeping the image's suffix.

    Characters other than letters, digits, '-' and '_' (including spaces and
    path separators) become '_'. An empty caption falls back to the original stem.
    """
    safe_stem = "".join(ch if (ch.isalnum() or ch in "-_") else "_" for ch in caption.strip())
    return (safe_stem or image_path.stem) + image_path.suffix


image_dir = Path(config['config']['paths']['images'])
# Sorted for a stable review order; is_file() skips directories named like images.
image_files = sorted(
    f for f in image_dir.glob('*') if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
)
# Caption length is configurable; 10 tokens matches the previous hard-coded value.
max_new_tokens = config['config']['models']['blip'].get('max_new_tokens', 10)

rows = []
for image_path in image_files:
    img, image_input = load_image(image_path)
    if img is None:
        continue
    # Cast inputs to the model's own dtype (fp16 on GPU, fp32 on CPU) so the
    # pixel tensor always matches the weights.
    blip_input = processor(img, return_tensors="pt").to(device, blip_model.dtype)
    predicted_ids = blip_model.generate(**blip_input, max_new_tokens=max_new_tokens)
    caption = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()
    predicted_category, image_features = predict_category(image_input, caption)
    rows.append({
        'image_path': str(image_path),
        'image_embedding': image_features,
        'predicted_category': predicted_category,
        'generated_text': _caption_to_filename(caption, image_path),
    })

# Build the frame once (row-by-row .loc enlargement is quadratic). The explicit
# columns keep the schema even when no images were found.
images_df = pd.DataFrame(
    rows, columns=['image_path', 'image_embedding', 'predicted_category', 'generated_text']
)
print(images_df.head())
# Gradio callbacks: step through the images and file them into category folders
def next_image_and_prediction(user_choice):
    """Record the user's category for the current image and advance to the next.

    Args:
        user_choice: category selected in the dropdown for the image currently
            shown. ``None`` (a cleared dropdown) keeps the existing prediction.

    Returns:
        ``(image_path, predicted_category, generated_filename)`` for the next
        image, wrapping back to the first after the last. With no images,
        returns ``(None, None, "No images to review")`` so all three Gradio
        outputs still receive a value.
    """
    global current_index
    if images_df.empty:
        # Guard: .loc on an empty frame would append a row, and `% 0` would raise.
        return None, None, "No images to review"
    if user_choice is not None:
        images_df.loc[current_index, 'predicted_category'] = user_choice
    current_index = (current_index + 1) % len(images_df)
    row = images_df.loc[current_index]
    next_img_path = row['image_path']
    predicted_category = row['predicted_category']
    predicted_filename = row['generated_text']
    print(f"Next image: {next_img_path}, Predicted category: {predicted_category}")
    return next_img_path, predicted_category, predicted_filename
def _unique_destination(directory, filename, source):
    """Return a path in *directory* for *filename* that does not clobber another file.

    If ``directory / filename`` exists and is not *source* itself, ``_1``,
    ``_2``, ... is appended to the stem until a free name (or *source*) is found.
    """
    candidate = directory / filename
    stem, suffix = candidate.stem, candidate.suffix
    counter = 1
    while candidate.exists() and candidate.resolve() != source.resolve():
        candidate = directory / f"{stem}_{counter}{suffix}"
        counter += 1
    return candidate


def move_images_to_category_folder():
    """Move every listed image into its category folder under the output path.

    Each image is renamed to its ``generated_text`` file name inside
    ``<paths.output>/<category path>``. Name clashes get a numeric suffix
    instead of overwriting an existing file. ``images_df['image_path']`` is
    updated after each move so a repeated Submit does not look for files at
    their old location. Rows with an unknown category, or whose source file no
    longer exists, are reported and skipped.
    """
    import shutil  # only this callback needs it

    output_root = Path(config['config']['paths']['output'])
    for index, row in images_df.iterrows():
        image_path = Path(row['image_path'])
        category_name = row['predicted_category']
        if category_name not in categories:
            print(f"Category {category_name} not found in categories.")
            continue
        if not image_path.exists():
            print(f"Skipping {image_path}: file not found.")
            continue
        category_dir = output_root / Path(categories[category_name]['path'])
        category_dir.mkdir(parents=True, exist_ok=True)
        new_image_path = _unique_destination(category_dir, row['generated_text'], image_path)
        if new_image_path.resolve() == image_path.resolve():
            continue  # already in place, e.g. moved by an earlier Submit
        # shutil.move falls back to copy+delete across filesystems, where Path.rename fails.
        shutil.move(str(image_path), str(new_image_path))
        images_df.at[index, 'image_path'] = str(new_image_path)
        print(f"Moved {image_path} to {new_image_path}")
# Initial display values for the first image. They are passed as `value=` at
# construction: Gradio processes a component's value there, and assigning
# `.value` afterwards may bypass that processing.
if images_df.empty:
    initial_image, initial_category, initial_filename = None, None, ""
else:
    first_row = images_df.iloc[0]
    initial_image = first_row['image_path']
    initial_category = first_row['predicted_category']
    initial_filename = first_row['generated_text']

with gr.Blocks() as blocks:
    image_block = gr.Image(label="Image", type="filepath", height=300, width=300, value=initial_image)
    # NOTE(review): edits to this textbox are not read by any callback; the
    # name used when moving files comes from images_df['generated_text'].
    filename = gr.Textbox(label="Filename", type="text", value=initial_filename)
    next_button = gr.Button("Next Image")
    category_dropdown = gr.Dropdown(
        label="Category",
        choices=list(categories.keys()),
        type="value",
        value=initial_category,
    )
    submit_button = gr.Button("Submit")
    # Submit moves every image into the folder of its (possibly user-corrected) category.
    submit_button.click(fn=move_images_to_category_folder, inputs=[], outputs=[])
    # Next stores the dropdown choice for the current image, then shows the next one.
    next_button.click(
        fn=next_image_and_prediction,
        inputs=category_dropdown,
        outputs=[image_block, category_dropdown, filename],
    )
blocks.launch()