safinal committed on
Commit
af5379c
·
verified ·
1 Parent(s): 566c156

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -27
app.py CHANGED
@@ -5,20 +5,18 @@ from PIL import Image
5
  import pandas as pd
6
  from sklearn.metrics.pairwise import cosine_similarity
7
  from tqdm import tqdm
 
 
 
8
 
9
 
10
  from token_classifier import load_token_classifier, predict
11
  from model import Model
12
  from dataset import RetrievalDataset
13
 
14
-
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  batch_size = 512
17
 
18
-
19
- import zipfile
20
- import os
21
-
22
  def unzip_file(zip_path, extract_path):
23
  # Create the target directory if it doesn't exist
24
  os.makedirs(extract_path, exist_ok=True)
@@ -28,40 +26,39 @@ def unzip_file(zip_path, extract_path):
28
  # Extract all contents to the specified directory
29
  zip_ref.extractall(extract_path)
30
 
31
- # Example usage
32
  zip_path = "sample_evaluation.zip"
33
  extract_path = "sample_evaluation"
34
- unzip_file(zip_path, extract_path)
35
-
36
- from huggingface_hub import hf_hub_download
37
- hf_hub_download(repo_id="safinal/compositional-image-retrieval", filename="weights.pth", local_dir='.')
38
 
 
 
 
39
 
40
 
41
- def encode_database(model, df: pd.DataFrame) -> np.ndarray :
42
  """
43
  Process database images and generate embeddings.
44
-
45
- Args:
46
- df (pd. DataFrame ): DataFrame with column:
47
- - target_image: str, paths to database images
48
-
49
- Returns:
50
- np.ndarray: Embeddings array (num_images, embedding_dim)
51
  """
52
  model.eval()
53
  all_embeddings = []
 
54
  for i in tqdm(range(0, len(df), batch_size)):
55
- target_imgs = torch.stack([model.processor(Image.open(target_image_path)) for target_image_path in df['target_image'][i:i+batch_size]]).to(device)
 
 
 
56
  with torch.no_grad():
57
- # target_imgs_embedding = model.encode_database_image(target_imgs)
58
  target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
59
  target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
60
  all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
 
 
 
61
  return np.concatenate(all_embeddings)
62
 
63
 
64
- # Load model and configurations
65
  def load_model():
66
  model = Model(model_name="ViTamin-L-384", pretrained=None)
67
  model.load("weights.pth")
@@ -70,7 +67,6 @@ def load_model():
70
 
71
 
72
  def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
73
-
74
  # Process query image
75
  query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
76
 
@@ -131,9 +127,12 @@ def process_single_query(model, query_image_path, query_text, database_embedding
131
 
132
  return most_similar_image_path
133
 
134
- # Initialize model and database
 
 
135
  model = load_model()
136
 
 
137
  test_dataset = RetrievalDataset(
138
  img_dir_path="sample_evaluation/images",
139
  annotations_file_path="sample_evaluation/data.csv",
@@ -142,19 +141,27 @@ test_dataset = RetrievalDataset(
142
  tokenizer=model.tokenizer
143
  )
144
 
145
- database_embeddings = encode_database(model, test_dataset.load_database()) # Using your existing function
 
 
 
 
146
 
147
  def interface_fn(selected_image, query_text):
 
 
 
148
  result_image_path = process_single_query(
149
  model,
150
  selected_image,
151
  query_text,
152
  database_embeddings,
153
- test_dataset.load_database()
154
  )
155
  return Image.open(result_image_path)
156
 
157
- # Create Gradio interface
 
158
  demo = gr.Interface(
159
  fn=interface_fn,
160
  inputs=[
@@ -170,7 +177,7 @@ demo = gr.Interface(
170
  ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
171
  ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
172
  ],
173
- allow_flagging=False,
174
  cache_examples=False
175
  )
176
 
 
5
  import pandas as pd
6
  from sklearn.metrics.pairwise import cosine_similarity
7
  from tqdm import tqdm
8
+ import zipfile
9
+ import os
10
+ from huggingface_hub import hf_hub_download
11
 
12
 
13
  from token_classifier import load_token_classifier, predict
14
  from model import Model
15
  from dataset import RetrievalDataset
16
 
 
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  batch_size = 512
19
 
 
 
 
 
20
  def unzip_file(zip_path, extract_path):
21
  # Create the target directory if it doesn't exist
22
  os.makedirs(extract_path, exist_ok=True)
 
26
  # Extract all contents to the specified directory
27
  zip_ref.extractall(extract_path)
28
 
29
+ # Setup files
30
  zip_path = "sample_evaluation.zip"
31
  extract_path = "sample_evaluation"
32
+ if os.path.exists(zip_path): # Check exists to prevent errors if already unzipped
33
+ unzip_file(zip_path, extract_path)
 
 
34
 
35
+ # Download weights if not present
36
+ if not os.path.exists("weights.pth"):
37
+ hf_hub_download(repo_id="safinal/compositional-image-retrieval", filename="weights.pth", local_dir='.')
38
 
39
 
40
+ def encode_database(model, df: pd.DataFrame) -> np.ndarray:
41
  """
42
  Process database images and generate embeddings.
 
 
 
 
 
 
 
43
  """
44
  model.eval()
45
  all_embeddings = []
46
+ # Ensure batching handles empty or small datasets gracefully
47
  for i in tqdm(range(0, len(df), batch_size)):
48
+ batch_df = df['target_image'][i:i+batch_size]
49
+ if len(batch_df) == 0: continue
50
+
51
+ target_imgs = torch.stack([model.processor(Image.open(target_image_path)) for target_image_path in batch_df]).to(device)
52
  with torch.no_grad():
 
53
  target_imgs_embedding = model.feature_extractor.encode_image(target_imgs)
54
  target_imgs_embedding = torch.nn.functional.normalize(target_imgs_embedding, dim=1, p=2)
55
  all_embeddings.append(target_imgs_embedding.detach().cpu().numpy())
56
+
57
+ if not all_embeddings:
58
+ return np.array([])
59
  return np.concatenate(all_embeddings)
60
 
61
 
 
62
  def load_model():
63
  model = Model(model_name="ViTamin-L-384", pretrained=None)
64
  model.load("weights.pth")
 
67
 
68
 
69
  def process_single_query(model, query_image_path, query_text, database_embeddings, database_df):
 
70
  # Process query image
71
  query_img = model.processor(Image.open(query_image_path)).unsqueeze(0).to(device)
72
 
 
127
 
128
  return most_similar_image_path
129
 
130
+ # --- Initialization ---
131
+
132
+ print("Loading model...")
133
  model = load_model()
134
 
135
+ print("Loading dataset...")
136
  test_dataset = RetrievalDataset(
137
  img_dir_path="sample_evaluation/images",
138
  annotations_file_path="sample_evaluation/data.csv",
 
141
  tokenizer=model.tokenizer
142
  )
143
 
144
+ # Load database once globally to avoid reloading it on every user request
145
+ print("Encoding database...")
146
+ database_df = test_dataset.load_database()
147
+ database_embeddings = encode_database(model, database_df)
148
+
149
 
150
  def interface_fn(selected_image, query_text):
151
+ if selected_image is None:
152
+ return None
153
+
154
  result_image_path = process_single_query(
155
  model,
156
  selected_image,
157
  query_text,
158
  database_embeddings,
159
+ database_df # Pass the pre-loaded DataFrame
160
  )
161
  return Image.open(result_image_path)
162
 
163
+ # --- Gradio Interface ---
164
+
165
  demo = gr.Interface(
166
  fn=interface_fn,
167
  inputs=[
 
177
  ["sample_evaluation/images/455007.png", "Discard chair in the beginning, then proceed to bring car into play."],
178
  ["sample_evaluation/images/612311.png", "Get rid of train initially, and then follow up by including snowboard."]
179
  ],
180
+ flagging_mode="never",
181
  cache_examples=False
182
  )
183