jichao committed on
Commit
87188fa
·
1 Parent(s): 0e5bea3
Files changed (2) hide show
  1. app.py +203 -185
  2. requirements.txt +4 -7
app.py CHANGED
@@ -1,209 +1,227 @@
1
  import gradio as gr
2
  import torch
3
  import timm
4
- import numpy as np
5
  from PIL import Image
6
- import torchvision.transforms as transforms
7
  import os
8
 
9
- # Define available models
10
- MODELS = {
11
- "mars-vit-b-0217": {
12
- "path": os.path.join('models', 'checkpint-300.pth'),
13
- "architecture": 'vit_base_patch16_224',
14
- "img_size": 224,
15
  "in_chans": 1,
16
- "mean": [0.5],
17
- "std": [0.25]
18
- }
19
- # Add more models here in the future
 
 
 
 
 
20
  }
21
 
22
- # Default model
23
- DEFAULT_MODEL = "mars-vit-b-0217"
24
 
25
- # Model cache to avoid reloading
26
- loaded_models = {}
 
 
 
27
 
28
- def get_transform(model_name):
29
- """Get the appropriate transform for the model"""
30
- model_config = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
31
- return transforms.Compose([
32
- transforms.Resize((model_config["img_size"], model_config["img_size"])),
33
- transforms.Grayscale(), # Convert to grayscale (1 channel)
34
- transforms.ToTensor(),
35
- transforms.Normalize(mean=model_config["mean"], std=model_config["std"])
36
- ])
37
 
38
- def load_model(model_name):
39
- """Load the specified model"""
40
- if model_name in loaded_models:
41
- return loaded_models[model_name]
42
-
43
- model_config = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
44
-
45
  model = timm.create_model(
46
- model_config["architecture"],
47
- img_size=model_config["img_size"],
48
- in_chans=model_config["in_chans"],
49
- num_classes=0, # no head
50
- global_pool='', # no pooling
 
51
  )
52
 
53
- # Load converted weights
54
- checkpoint = torch.load(model_config["path"], map_location='cpu', weights_only=False)
55
- msg = model.load_state_dict(checkpoint['state_dict'], strict=False)
56
- print(f"Loaded {model_name} weights with message: {msg}")
57
-
58
- model.eval() # Set model to evaluation mode
59
- loaded_models[model_name] = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  return model
61
 
62
- # Load the default model at startup
63
- default_model = load_model(DEFAULT_MODEL)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- def get_embedding(image, model_name=DEFAULT_MODEL):
66
- """Calculate embedding for an image using the specified model"""
67
- if image is None:
68
- return None, "No image provided"
69
-
70
  try:
71
- # Get the model
72
- model = load_model(model_name)
73
-
74
- # Convert to PIL Image if it's not already
75
- if not isinstance(image, Image.Image):
76
- image = Image.fromarray(image)
77
-
78
- # Apply transformations
79
- transform = get_transform(model_name)
80
- img_tensor = transform(image).unsqueeze(0) # Add batch dimension
81
-
82
- # Get embedding
83
  with torch.no_grad():
84
- embedding = model(img_tensor)
85
-
86
- # Convert to numpy and normalize
87
- embedding_np = embedding.squeeze().cpu().numpy()
88
-
89
- # Normalize embedding to unit length
90
- embedding_norm = embedding_np / np.linalg.norm(embedding_np)
91
-
92
- return embedding_norm, f"Embedding calculated successfully using {model_name}"
 
 
 
 
 
 
 
 
 
 
 
 
93
  except Exception as e:
94
- return None, f"Error calculating embedding: {str(e)}"
95
-
96
- def process_image(image, model_name=DEFAULT_MODEL):
97
- """Process image and return embedding with visualization"""
98
- embedding, message = get_embedding(image, model_name)
99
-
100
- if embedding is None:
101
- return None, None, message, None
102
-
103
- # Create a simple visualization of the embedding (first 100 values)
104
- import matplotlib.pyplot as plt
105
- plt.figure(figsize=(10, 4))
106
- plt.bar(range(min(100, len(embedding))), embedding[:100])
107
- plt.title(f"Embedding Visualization ({model_name}, first 100 dimensions)")
108
- plt.xlabel("Dimension")
109
- plt.ylabel("Value")
110
-
111
- # Save the plot to a temporary file
112
- vis_path = "embedding_vis.png"
113
- plt.savefig(vis_path)
114
- plt.close()
115
-
116
- # Return the processed image, embedding visualization, and message
117
- return image, vis_path, message, embedding.tolist()
118
-
119
- # Define API endpoint function
120
- def api_predict(image, model_name=DEFAULT_MODEL):
121
- embedding, message = get_embedding(image, model_name)
122
- if embedding is None:
123
- return {"embedding": None, "message": message, "model_name": model_name}
124
- return {"embedding": embedding.tolist(), "message": message, "model_name": model_name}
125
-
126
- # Set up the Gradio interface with API
127
- demo = gr.Blocks()
128
-
129
- with demo:
130
- gr.Markdown("# Image Embedding Calculator")
131
- gr.Markdown("Upload an image to calculate its embedding vector using a Vision Transformer model")
132
-
133
- with gr.Tab("Interactive Demo"):
134
- with gr.Row():
135
- with gr.Column():
136
- input_image = gr.Image(type="pil", label="Input Image")
137
- model_dropdown = gr.Dropdown(
138
- choices=list(MODELS.keys()),
139
- value=DEFAULT_MODEL,
140
- label="Model"
141
- )
142
- submit_btn = gr.Button("Calculate Embedding")
143
-
144
- with gr.Column():
145
- output_image = gr.Image(type="pil", label="Processed Image")
146
- output_vis = gr.Image(type="filepath", label="Embedding Visualization")
147
- output_message = gr.Textbox(label="Status")
148
- output_embedding = gr.JSON(label="Embedding Vector")
149
-
150
- submit_btn.click(
151
- fn=process_image,
152
- inputs=[input_image, model_dropdown],
153
- outputs=[output_image, output_vis, output_message, output_embedding]
154
- )
155
-
156
- with gr.Tab("API Documentation"):
157
- gr.Markdown("""
158
- ## API Usage
159
-
160
- This application provides an API endpoint for calculating image embeddings.
161
-
162
- ### Endpoint: `/api/predict`
163
-
164
- **Method**: POST
165
-
166
- **Input**:
167
- - `image`: An image file
168
- - `model_name`: (Optional) Name of the model to use (default: "mars-vit-b-0217")
169
-
170
- **Output**:
171
- ```json
172
- {
173
- "embedding": [...], // The embedding vector
174
- "message": "Status message",
175
- "model_name": "mars-vit-b-0217" // The model used
176
  }
177
- ```
178
-
179
- ### Example using Python requests:
180
- ```python
181
- import requests
182
-
183
- response = requests.post(
184
- "https://yourusername-embedding-helper.hf.space/api/predict",
185
- files={"image": open("your_image.jpg", "rb")},
186
- data={"model_name": "mars-vit-b-0217"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
-
189
- result = response.json()
190
- embedding = result["embedding"]
191
- ```
192
- """)
193
-
194
- # Create the API endpoint
195
- api_predict_interface = gr.Interface(
196
- fn=api_predict,
197
- inputs=[
198
- gr.Image(type="pil"),
199
- gr.Textbox(default=DEFAULT_MODEL, label="Model Name")
200
- ],
201
- outputs=gr.JSON(),
202
- title="Image Embedding API",
203
- description="API for calculating image embeddings",
204
- allow_flagging="never"
205
- )
206
-
207
- # Launch the app with the API
208
  if __name__ == "__main__":
209
- demo.launch(share=False)
 
1
  import gradio as gr
2
  import torch
3
  import timm
4
+ from torchvision import transforms
5
  from PIL import Image
6
+ import numpy as np
7
  import os
8
 
9
# --- Model Configuration ---
DEFAULT_MODEL_NAME = "mars-ctx-vitb-0217"

# Registry of available checkpoints: maps a user-facing model name to its
# checkpoint path, timm architecture id, expected input channel count, and a
# short human-readable description.
MODEL_CONFIGS = {
    "mars-ctx-vitb-0217": {
        "path": "models/checkpoint-300.pth",
        "timm_id": "vit_base_patch16_224",
        "in_chans": 1,
        "description": "ViT-Base/16 (Grayscale Input)"
    },
    # --- Add more model configurations here ---
    # "another_model_name": {
    #     "path": "models/another_checkpoint.pth",
    #     "timm_id": "vit_small_patch16_224",
    #     "in_chans": 3,  # Example: RGB model
    #     "description": "ViT-Small/16 (RGB Input)"
    # },
}

# Cache of instantiated models keyed by model name, so each checkpoint is
# loaded from disk at most once per process.
LOADED_MODELS = {}
29
 
30
# --- Model Loading Function ---
def load_model(model_name: str):
    """Build the timm backbone for ``model_name`` and load its checkpoint.

    Looks up the configuration in ``MODEL_CONFIGS``, constructs the bare
    backbone (no classification head, no pooling), and loads weights from the
    configured checkpoint if it exists. Falls back to randomly initialized
    weights when the checkpoint is missing or fails to load, so the app can
    still start.

    Args:
        model_name: Key into ``MODEL_CONFIGS``.

    Returns:
        The model, set to eval mode.

    Raises:
        ValueError: If ``model_name`` is not a known configuration.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model name: {model_name}")

    config = MODEL_CONFIGS[model_name]
    model_path = config["path"]
    timm_id = config["timm_id"]
    in_chans = config.get("in_chans", 3)  # Default to 3 channels if not specified
    # Generalized: honor an optional per-model input resolution; 224 matches
    # the vit_*_224 architectures used so far, so behavior is unchanged for
    # existing configs.
    img_size = config.get("img_size", 224)

    print(f"Loading model: {model_name} ({timm_id}) from {model_path}")

    model = timm.create_model(
        timm_id,
        img_size=img_size,
        in_chans=in_chans,
        num_classes=0,    # No classification head
        global_pool='',   # No pooling - we want the CLS token feature
        pretrained=False  # Don't load timm pretrained weights, we use our checkpoint
    )

    # Ensure the directory exists before checking the file
    model_dir = os.path.dirname(model_path)
    if model_dir and not os.path.exists(model_dir):
        print(f"Creating directory: {model_dir}")
        os.makedirs(model_dir, exist_ok=True)

    if not os.path.exists(model_path):
        print(f"Warning: Model checkpoint not found at {model_path}. Using random weights for {model_name}.")
        model.eval()  # Still set to eval mode
        return model  # Return untrained model if checkpoint missing

    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        # The checkpoint may wrap the weights under 'state_dict' or be a raw
        # state dict itself.
        state_dict = checkpoint.get('state_dict', checkpoint)
        # strict=False tolerates head/pooling keys that this bare backbone
        # does not have.
        msg = model.load_state_dict(state_dict, strict=False)
        print(f"Loaded weights for {model_name} from {model_path}. Load message: {msg}")
        if msg.missing_keys or msg.unexpected_keys:
            print(f"Note: There were missing or unexpected keys during weight loading for {model_name}. Check compatibility.")

    except Exception as e:
        print(f"Error loading checkpoint for {model_name} from {model_path}: {e}")
        print(f"Proceeding with randomly initialized weights for {model_name}.")

    model.eval()  # Set model to evaluation mode
    return model
78
 
79
# --- Pre-load Default Model --- (Or load on demand in get_embedding)
try:
    print(f"Pre-loading default model: {DEFAULT_MODEL_NAME}...")
    LOADED_MODELS[DEFAULT_MODEL_NAME] = load_model(DEFAULT_MODEL_NAME)
    print(f"Default model {DEFAULT_MODEL_NAME} loaded successfully.")
except Exception as e:
    # Startup deliberately continues so the UI can come up; requests that need
    # the default model will surface their own errors later.
    print(f"ERROR: Failed to pre-load default model {DEFAULT_MODEL_NAME}: {e}")
89
# --- Image Preprocessing --- (Now depends on model input channels)
def get_preprocess(model_name: str):
    """Return the torchvision preprocessing pipeline for ``model_name``.

    Resizes to the model's input resolution, converts to grayscale when the
    model expects a single channel, then tensorizes and normalizes with a
    per-channel mean of 0.5 and std of 0.25. Unknown model names fall back to
    the default model's configuration.
    """
    config = MODEL_CONFIGS.get(model_name, MODEL_CONFIGS[DEFAULT_MODEL_NAME])  # Fallback to default
    in_chans = config.get('in_chans', 3)
    # Generalized: honor an optional per-model input resolution; defaults to
    # 224 so existing configs behave exactly as before.
    img_size = config.get('img_size', 224)
    mean = [0.5] * in_chans
    std = [0.25] * in_chans  # Assuming same normalization for now

    transforms_list = [
        transforms.Resize((img_size, img_size)),
    ]
    if in_chans == 1:
        transforms_list.append(transforms.Grayscale(num_output_channels=1))

    transforms_list.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    return transforms.Compose(transforms_list)
108
+
109
# --- Embedding Function ---
def get_embedding(image_pil: Image.Image, model_name: str) -> dict:
    """Compute the normalized CLS-token embedding of an image.

    Preprocesses the image for the selected model, extracts features, takes
    the CLS token, L2-normalizes it, and returns a dict with keys
    ``model_name``, ``data`` (list of floats, or None on failure), and
    ``message`` (status text).
    """
    def _result(data, message):
        # Uniform response payload for both success and error paths.
        return {
            "model_name": model_name,
            "data": data,
            "message": message
        }

    if image_pil is None:
        return _result(None, "Error: Please upload an image.")
    if model_name not in MODEL_CONFIGS:
        return _result(None, f"Error: Unknown model name '{model_name}'.")

    # --- Get the model (load lazily if it was not pre-loaded) ---
    if model_name not in LOADED_MODELS:
        try:
            print(f"Loading model {model_name} on demand...")
            LOADED_MODELS[model_name] = load_model(model_name)
            print(f"Model {model_name} loaded successfully.")
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return _result(None, f"Error loading model '{model_name}'. Check logs.")

    selected_model = LOADED_MODELS[model_name]
    preprocess = get_preprocess(model_name)

    try:
        # Preprocess based on the selected model's requirements.
        img_tensor = preprocess(image_pil).unsqueeze(0)  # Add batch dimension [1, C, H, W]

        with torch.no_grad():
            features = selected_model.forward_features(img_tensor)
            if isinstance(features, tuple):
                features = features[0]
            if len(features.shape) == 3:
                # Token sequence [B, N, D]: position 0 is the CLS token.
                cls_embedding = features[:, 0]
            else:
                print(f"Warning: Unexpected feature shape for {model_name}: {features.shape}. Attempting to use as is.")
                cls_embedding = features

        # L2-normalize so downstream cosine similarity reduces to a dot product.
        normalized_embedding = torch.nn.functional.normalize(cls_embedding, p=2, dim=1)

        embedding_list = normalized_embedding.squeeze().cpu().numpy().tolist()
        if not isinstance(embedding_list, list):
            embedding_list = [embedding_list]  # Ensure it's always a list

        return _result(embedding_list, "Success")

    except Exception as e:
        print(f"Error processing image with model {model_name}: {e}")
        import traceback
        traceback.print_exc()  # Print detailed traceback to logs
        return _result(None, f"Error processing image with model '{model_name}'. Check logs for details.")
181
+
182
# --- Gradio Interface ---
EXAMPLE_DIR = "examples"
EXAMPLE_IMAGE = os.path.join(EXAMPLE_DIR, "sample_image.png")
os.makedirs(EXAMPLE_DIR, exist_ok=True)
# Only offer examples when the sample file is actually present on disk.
examples = [[EXAMPLE_IMAGE, DEFAULT_MODEL_NAME]] if os.path.exists(EXAMPLE_IMAGE) else None

# Dropdown choices come straight from the model registry.
model_choices = list(MODEL_CONFIGS.keys())

with gr.Blocks() as iface:
    gr.Markdown("## Image Embedding Calculator")
    gr.Markdown("Upload an image and select a model to calculate its normalized CLS token embedding.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Image")
            model_selector = gr.Dropdown(
                choices=model_choices,
                value=DEFAULT_MODEL_NAME,
                label="Select Model"
            )
            submit_btn = gr.Button("Calculate Embedding")
        with gr.Column(scale=2):
            # JSON output carries the embedding plus model/status metadata.
            output_embedding = gr.JSON(label="Output (Embedding & Info)")

    if examples:
        gr.Examples(
            examples=examples,
            inputs=[input_image, model_selector],
            outputs=output_embedding,
            fn=get_embedding,
            cache_examples=False  # Recompute each time; could be True for static inputs
        )

    # The click handler doubles as the public API endpoint (/predict).
    submit_btn.click(
        fn=get_embedding,
        inputs=[input_image, model_selector],
        outputs=output_embedding,
        api_name="predict"
    )

# --- Launch the App ---
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0")
requirements.txt CHANGED
@@ -1,7 +1,4 @@
1
- torch>=2.0.0
2
- torchvision>=0.10.0
3
- timm>=1.0.0
4
- gradio
5
- numpy<2.0.0
6
- Pillow>=8.3.1
7
- matplotlib>=3.5.0
 
1
+ torch
2
+ timm
3
+ torchvision
4
+ Pillow