amitkumarjaiswal commited on
Commit
c6a4a94
·
verified ·
1 Parent(s): 99a5800

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -0
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from pathlib import Path
4
+ from PIL import Image
5
+ import torch
6
+ import clip
7
+ import yaml
8
+ import pandas as pd
9
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
10
+ from pprint import pprint as print
11
+
12
# Category name -> details mapping; this empty placeholder is replaced with
# the real data from config.yml right after load_config() runs below.
categories = {}

# Configuration loading and validation
14
def load_config(path):
    """Load and validate the YAML configuration file.

    Args:
        path: Filesystem path to the YAML config (e.g. 'config.yml').

    Returns:
        dict: Parsed configuration, guaranteed to contain the top-level
        'categories' and 'config' sections.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the document is empty / not a mapping, or a required
            top-level section is missing.
    """
    try:
        with open(path) as file:
            # NOTE(review): full_load trusts the local config file; do not
            # point this at untrusted input (prefer yaml.safe_load there).
            config = yaml.full_load(file)
        # An empty or scalar YAML document parses to None/str; fail with a
        # clear message instead of an opaque TypeError on `key not in config`.
        if not isinstance(config, dict):
            raise ValueError(f'Missing necessary config section: categories')
        # Validate necessary sections are present
        necessary_keys = ['categories', 'config']
        for key in necessary_keys:
            if key not in config:
                raise ValueError(f'Missing necessary config section: {key}')
        return config
    except FileNotFoundError:
        print("Error: config.yml file not found.")
        raise
    except ValueError as e:
        print(str(e))
        raise
30
+
31
# Load the validated configuration and expose the category table globally;
# load_config() raises if config.yml is missing or malformed.
config = load_config('config.yml')
categories = config['categories']


# Run on GPU when available; all tensors below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


# Initialize models and processor
# BLIP-2 generates captions; CLIP scores image/text similarity.
processor = AutoProcessor.from_pretrained(config['config']['models']['blip']['model_name'])
# NOTE(review): float16 weights are intended for GPU inference — confirm this
# still works when device resolves to "cpu".
blip_model = Blip2ForConditionalGeneration.from_pretrained(config['config']['models']['blip']['model_name'], torch_dtype=torch.float16)

blip_model.to(device)
# clip.load returns both the model and its matching image preprocessor.
model, preprocess = clip.load(config['config']['models']['clip']['model_name'], device=device)

# Index of the image currently shown in the Gradio UI (advanced by the
# "Next Image" callback below).
current_index = 0

# Load categories from a YAML configuration


# Precompute category embeddings
# Each category's free-text description is embedded once with CLIP so the
# per-image classification loop only computes the image side.
# NOTE(review): encode_text runs outside torch.no_grad() here — harmless for
# the values, but wraps them in autograd graphs until .detach() below.
for category_name, category_details in categories.items():
    print(f"Precomputing embeddings for category: {category_name}; {category_details}")
    embeddings_tensor = model.encode_text(clip.tokenize(category_details['description']).to(device))
    # Stored as numpy so predict_category can score against numpy features.
    category_details['embeddings'] = embeddings_tensor.detach().cpu().numpy()
56
+
57
def load_image(path):
    """Open an image file and build its preprocessed CLIP input tensor.

    Args:
        path: Path to the image file.

    Returns:
        A (PIL.Image, tensor) pair on success. On any failure the error is
        reported and (None, None) is returned, so the batch loop over all
        images can simply skip the bad file and continue.
    """
    try:
        pil_image = Image.open(path)
        clip_input = preprocess(pil_image)
        clip_input = clip_input.unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error loading image {path}: {e}")
        return None, None
    return pil_image, clip_input
65
+
66
def predict_category(image_input, caption_input=None):
    """Pick the category whose precomputed text embedding best matches the image.

    Args:
        image_input: Preprocessed CLIP image tensor (batched), or None.
        caption_input: Optional caption string; when given, its CLIP text
            embedding is stacked with the image embedding so both rows
            contribute to each category's similarity score.

    Returns:
        (best_category, image_features): the winning category name (None when
        image_input is None or no categories exist) and the L2-normalized
        feature matrix as a numpy array (None when image_input is None).
    """
    if image_input is None:
        return None, None
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        if caption_input is not None:
            caption_tokens = clip.tokenize(caption_input).to(device)
            text_features = model.encode_text(caption_tokens)
            # Stack image + caption rows; the per-category sum below scores both.
            image_features = torch.cat([image_features, text_features])
        # Normalize each row so the dot products below are cosine-like.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        image_features = image_features.cpu().numpy()
    best_category = None
    # BUG FIX: category embeddings are NOT normalized, so similarities can be
    # arbitrarily negative; the old sentinel of -1 could reject every category.
    best_similarity = float("-inf")
    for category_name, category_details in categories.items():
        similarity = (image_features * category_details['embeddings']).sum()
        if similarity > best_similarity:
            best_similarity = similarity
            best_category = category_name
    return best_category, image_features
85
+
86
+
87
# Gather the images to process: every png/jpg/jpeg directly inside the
# configured images directory (non-recursive).
image_dir = Path(config['config']['paths']['images'])
image_files = [f for f in image_dir.glob('*') if f.suffix.lower() in ['.png', '.jpg', '.jpeg']]

# One row per successfully loaded image: original path, CLIP features,
# predicted category, and the BLIP caption turned into a filename.
images_df = pd.DataFrame(columns=['image_path', 'image_embedding', 'predicted_category', 'generated_text'])
for image_path in image_files:
    img, image_input = load_image(image_path)
    # load_image returns (None, None) on failure — skip unreadable files.
    if img is not None:
        # BLIP-2 expects float16 inputs to match the model's loaded dtype.
        blip_input = processor(img, return_tensors="pt").to(device, torch.float16)
        # Ensure generation settings are compatible
        # Short captions only: a handful of tokens is enough for a filename.
        predicted_ids = blip_model.generate(**blip_input, max_new_tokens=10)
        generated_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0].strip()

        # The raw caption also feeds CLIP so text similarity helps pick a category.
        predicted_category, image_features = predict_category(image_input, generated_text)
        # Repurpose the caption as a filesystem-safe filename, keeping the
        # original extension.
        generated_text = generated_text.replace(" ", "_") + image_path.suffix

        new_row = {
            'image_path': str(image_path),
            'image_embedding': image_features if image_features is not None else None,
            'predicted_category': predicted_category,
            'generated_text': generated_text
        }
        # Using direct indexing to add to the DataFrame
        index = len(images_df)
        images_df.loc[index] = new_row


print(images_df.head())
114
+ # Gradio interface setup and launch
115
def next_image_and_prediction(user_choice):
    """Gradio callback: record the user's category choice and advance.

    Saves `user_choice` as the predicted category of the currently shown
    image, then wraps around to the next row of images_df.

    Args:
        user_choice: Category the user selected in the dropdown.

    Returns:
        (image_path, category, filename) for the next image — always a
        3-tuple, matching the three output components wired to this callback.
    """
    global current_index
    # BUG FIX: with an empty DataFrame the modulo below raised
    # ZeroDivisionError; return the "no images" 3-tuple instead.
    if images_df.empty:
        return None, "No more images", None
    images_df.loc[current_index, 'predicted_category'] = user_choice
    # Wrap around so the reviewer can cycle through the set repeatedly.
    # (The old `if current_index < len(images_df)` check was always true after
    # the modulo, and its dead else-branch returned a 2-tuple that would have
    # crashed the 3-output Gradio wiring.)
    current_index = (current_index + 1) % len(images_df)
    next_img_path = images_df.loc[current_index, 'image_path']
    predicted_category = images_df.loc[current_index, 'predicted_category']
    predicted_filename = images_df.loc[current_index, 'generated_text']
    print(f"Next image: {next_img_path}, Predicted category: {predicted_category}")
    return next_img_path, predicted_category, predicted_filename
127
+
128
def move_images_to_category_folder():
    """Move every processed image into its predicted category's folder.

    For each row of images_df, builds <output>/<category path>/ (created on
    demand) and renames the image to its generated filename inside it. Rows
    with an unknown category, or whose move fails (already moved by a second
    Submit click, cross-device rename, existing target on Windows), are
    reported and skipped so one bad file does not abort the whole batch.
    """
    for index, row in images_df.iterrows():
        image_path = Path(row['image_path'])
        category_name = row['predicted_category']
        if category_name in categories:
            category_path = Path(categories[category_name]['path'])
            category_dir = Path(config['config']['paths']['output']) / category_path
            category_dir.mkdir(parents=True, exist_ok=True)
            new_image_path = category_dir / row['generated_text']
            try:
                # Path.rename raises OSError on cross-filesystem moves, a
                # missing source, or an existing target (Windows).
                image_path.rename(new_image_path)
            except OSError as e:
                print(f"Error moving {image_path} to {new_image_path}: {e}")
                continue
            print(f"Moved {image_path} to {new_image_path}")
        else:
            print(f"Category {category_name} not found in categories.")
141
+
142
+
143
+
144
+
145
+
146
# Gradio UI: shows one image at a time with its predicted category and
# generated filename. "Next Image" stores the dropdown choice for the current
# image and advances; "Submit" moves all images into their category folders.
with gr.Blocks() as blocks:
    image_block = gr.Image(label="Image", type="filepath", height=300, width=300)
    filename = gr.Textbox(label="Filename", type="text")
    next_button = gr.Button("Next Image")
    category_dropdown = gr.Dropdown(label="Category", choices=list(categories.keys()), type="value")
    submit_button = gr.Button("Submit")
    submit_button.click(fn=move_images_to_category_folder, inputs=[], outputs=[])
    # next_image_and_prediction must return exactly three values to fill
    # these three output components.
    next_button.click(fn=next_image_and_prediction, inputs=category_dropdown, outputs=[image_block, category_dropdown, filename])

# Seed the UI with the first image before launching, if any were processed.
if not images_df.empty:
    img_path, predicted_category = images_df.loc[0, ['image_path', 'predicted_category']]
    # NOTE(review): assigning .value after component construction may not
    # update the rendered UI in recent Gradio versions — confirm, or pass
    # these as `value=` constructor arguments instead.
    image_block.value = img_path
    category_dropdown.value = predicted_category

blocks.launch()