jichao commited on
Commit
4a47660
·
1 Parent(s): b5fe4f8

first model added

Browse files
Files changed (2) hide show
  1. app.py +195 -4
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,198 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import timm
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import os
8
 
9
# Registry of available embedding models.
# Each entry maps a model name to its checkpoint path plus the timm
# architecture and preprocessing settings needed to rebuild/use it.
MODELS = {
    "mars-vit-b-0217": {
        # NOTE(review): "checkpint" looks like a typo for "checkpoint" —
        # confirm it matches the actual filename on disk before renaming.
        "path": os.path.join('models', 'checkpint-300.pth'),
        "architecture": 'vit_base_patch16_224',
        "img_size": 224,   # input resolution the ViT was trained at
        "in_chans": 1,     # single-channel (grayscale) input
        "mean": [0.5],     # normalization stats for the one channel
        "std": [0.25]
    }
    # Add more models here in the future
}

# Default model
DEFAULT_MODEL = "mars-vit-b-0217"

# Model cache to avoid reloading (name -> loaded eval-mode model)
loaded_models = {}
27
+
28
def get_transform(model_name):
    """Build the preprocessing pipeline for *model_name*.

    Unknown names fall back to the default model's configuration.
    """
    cfg = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    side = cfg["img_size"]
    steps = [
        transforms.Resize((side, side)),
        transforms.Grayscale(),  # collapse to a single channel
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg["mean"], std=cfg["std"]),
    ]
    return transforms.Compose(steps)
37
+
38
def load_model(model_name):
    """Load (and cache) the timm model for *model_name*.

    Unknown names fall back to the default model's configuration. The model
    is returned in eval mode; repeated calls reuse the cached instance.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    model_config = MODELS.get(model_name, MODELS[DEFAULT_MODEL])

    model = timm.create_model(
        model_config["architecture"],
        img_size=model_config["img_size"],
        in_chans=model_config["in_chans"],
        num_classes=0,   # no classification head — we want raw features
        global_pool='',  # no pooling — keep the full token sequence
    )

    # SECURITY: weights_only=False unpickles arbitrary objects; only ever
    # load checkpoints bundled with / trusted by this app.
    checkpoint = torch.load(model_config["path"], map_location='cpu', weights_only=False)
    # Accept both {"state_dict": ...} wrappers and bare state dicts —
    # the original indexed checkpoint['state_dict'] unconditionally and
    # raised KeyError on checkpoints saved as a plain state dict.
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get('state_dict', checkpoint)
    else:
        state_dict = checkpoint
    msg = model.load_state_dict(state_dict, strict=False)
    print(f"Loaded {model_name} weights with message: {msg}")

    model.eval()  # inference only: freeze dropout / norm statistics
    loaded_models[model_name] = model
    return model
61
+
62
# Eagerly load the default model at startup so the first request is fast
# (load_model also caches it in loaded_models; this binding itself is
# not referenced again below).
default_model = load_model(DEFAULT_MODEL)
64
+
65
def get_embedding(image, model_name=DEFAULT_MODEL):
    """Calculate a unit-normalized embedding for an image.

    Args:
        image: PIL Image or numpy array; None is tolerated.
        model_name: key into MODELS (unknown names use the default model).

    Returns:
        (embedding, message): embedding is a numpy array scaled to unit
        length, or None on failure (message then explains why).
    """
    if image is None:
        return None, "No image provided"

    try:
        # Get the (cached) model
        model = load_model(model_name)

        # Convert to PIL Image if it's not already
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        # Apply transformations
        transform = get_transform(model_name)
        img_tensor = transform(image).unsqueeze(0)  # add batch dimension

        # Get embedding
        with torch.no_grad():
            embedding = model(img_tensor)

        # Convert to numpy
        embedding_np = embedding.squeeze().cpu().numpy()

        # Normalize embedding to unit length. Guard against a zero vector:
        # the original divided unconditionally, which yields NaNs when the
        # norm is 0.
        norm = np.linalg.norm(embedding_np)
        if norm == 0:
            return None, "Error calculating embedding: zero-norm embedding"
        embedding_norm = embedding_np / norm

        return embedding_norm, f"Embedding calculated successfully using {model_name}"
    except Exception as e:
        # Broad catch is deliberate: surface failures as a UI status message
        # instead of crashing the Gradio app.
        return None, f"Error calculating embedding: {str(e)}"
95
+
96
def process_image(image, model_name=DEFAULT_MODEL):
    """Compute an image's embedding plus a bar-chart visualization.

    Returns (image, vis_path, message, embedding_list); all but the message
    are None when embedding extraction failed.
    """
    embedding, message = get_embedding(image, model_name)

    if embedding is None:
        return None, None, message, None

    # Visualize the first 100 embedding dimensions as a bar chart.
    # Imported lazily, matching the original, so app startup stays light.
    import matplotlib.pyplot as plt
    import tempfile

    fig = plt.figure(figsize=(10, 4))
    plt.bar(range(min(100, len(embedding))), embedding[:100])
    plt.title(f"Embedding Visualization ({model_name}, first 100 dimensions)")
    plt.xlabel("Dimension")
    plt.ylabel("Value")

    # Save to a unique temp file: the original's fixed "embedding_vis.png"
    # gets clobbered when two requests run concurrently.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        vis_path = tmp.name
    plt.savefig(vis_path)
    plt.close(fig)  # close the specific figure to avoid leaking memory

    # Return the processed image, embedding visualization, and message
    return image, vis_path, message, embedding.tolist()
118
+
119
# Define the Gradio interface
with gr.Blocks() as demo:
    # Page header
    gr.Markdown("# Image Embedding Calculator")
    gr.Markdown("Upload an image to calculate its embedding vector using a Vision Transformer model")

    with gr.Tab("Interactive Demo"):
        with gr.Row():
            with gr.Column():
                # Input side: image upload, model picker, trigger button
                input_image = gr.Image(type="pil", label="Input Image")
                model_dropdown = gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=DEFAULT_MODEL,
                    label="Model"
                )
                submit_btn = gr.Button("Calculate Embedding")

            with gr.Column():
                # Output side: echoed image, plot, status text, raw vector
                output_image = gr.Image(type="pil", label="Processed Image")
                output_vis = gr.Image(type="filepath", label="Embedding Visualization")
                output_message = gr.Textbox(label="Status")
                output_embedding = gr.JSON(label="Embedding Vector")

        # Wire the button to the processing function
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, model_dropdown],
            outputs=[output_image, output_vis, output_message, output_embedding]
        )

    with gr.Tab("API Documentation"):
        # Static usage docs. String content is kept at column 0 so Markdown
        # does not treat the indentation as a code block.
        gr.Markdown("""
## API Usage

This application provides an API endpoint for calculating image embeddings.

### Endpoint: `/api/predict`

**Method**: POST

**Input**:
- `image`: An image file
- `model_name`: (Optional) Name of the model to use (default: "mars-vit-b-0217")

**Output**:
```json
{
  "embedding": [...], // The embedding vector
  "message": "Status message",
  "model_name": "mars-vit-b-0217" // The model used
}
```

### Example using Python requests:
```python
import requests

response = requests.post(
    "https://yourusername-embedding-helper.hf.space/api/predict",
    files={"image": open("your_image.jpg", "rb")},
    data={"model_name": "mars-vit-b-0217"}
)

result = response.json()
embedding = result["embedding"]
```
""")
184
+
185
# Define API endpoint function
def api_predict(image, model_name=DEFAULT_MODEL):
    """Return a JSON-serializable dict with the image's embedding."""
    vector, message = get_embedding(image, model_name)
    payload = vector.tolist() if vector is not None else None
    return {"embedding": payload, "message": message, "model_name": model_name}
191
+
192
# Enable request queuing so concurrent users are served in order.
demo.queue()

# BUG FIX: the original did `demo = gr.mount_gradio_app(app=demo, blocks=demo,
# path="/")`, but mount_gradio_app expects a FastAPI application as `app`, not
# the Blocks itself — and rebinding `demo` to its return value would leave
# `demo.launch()` below called on an object without a launch() method.
# demo.launch() already serves the UI and the /api endpoints on its own.

# Launch the app
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ timm>=1.0.0
4
+ gradio
5
+ numpy<2.0.0
6
+ Pillow>=8.3.1
7
+ matplotlib>=3.5.0