101Frost committed on
Commit
606d030
·
verified ·
1 Parent(s): 60af893

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import open_clip
5
+ from mobileclip.modules.common.mobileone import reparameterize_model
6
+ from PIL import Image
7
+ import requests
8
+ from io import BytesIO
9
+ import logging
10
+ import numpy as np
11
+
12
# Logging: INFO-level basic configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application object together with its title/description/version metadata.
app = FastAPI(
    title="MobileCLIP API",
    description="API for MobileCLIP image and text embeddings",
    version="1.0.0",
)

# Model state shared by the request handlers; all three names stay None
# until load_model() assigns them.
model = preprocess = tokenizer = None
26
+
27
class TextRequest(BaseModel):
    """Request body of the /text-embedding endpoint."""

    # Text to embed.
    text: str
29
+
30
class ImageRequest(BaseModel):
    """Request body of the /image-embedding endpoint."""

    # URL the image is downloaded from.
    image_url: str
32
+
33
class SimilarityRequest(BaseModel):
    """Request body of the /similarity endpoint: one image and one text."""

    # URL the image is downloaded from.
    image_url: str
    # Text compared against the image.
    text: str
36
+
37
class EmbeddingResponse(BaseModel):
    """Response body of the two embedding endpoints."""

    # Embedding as a plain list; the endpoints fill it from ndarray.tolist().
    embedding: list
39
+
40
class SimilarityResponse(BaseModel):
    """Response body of the /similarity endpoint."""

    # Image/text similarity score.
    similarity: float
42
+
43
def load_model():
    """Load MobileCLIP-S2 and publish it through the module globals.

    Creates the model and its preprocessing transform with ``open_clip``
    (``datacompdr`` pretrained weights), fetches the matching tokenizer,
    switches the model to eval mode and reparameterizes it for inference.

    The globals ``model``, ``preprocess`` and ``tokenizer`` are assigned only
    after every step has succeeded, so a failure part-way through never
    leaves a half-initialised model visible to the request handlers (which
    use ``model is None`` as their "not loaded" check).

    Raises:
        Exception: whatever the failing loading step raised; it is logged
            with its traceback and re-raised unchanged.
    """
    global model, preprocess, tokenizer

    try:
        logger.info("Loading MobileCLIP model...")
        loaded_model, _, loaded_preprocess = open_clip.create_model_and_transforms(
            'MobileCLIP-S2', pretrained='datacompdr'
        )
        loaded_tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

        # Reparameterize for inference
        loaded_model.eval()
        loaded_model = reparameterize_model(loaded_model)
    except Exception:
        # logger.exception records the traceback; the bare raise keeps the
        # original exception and its stack intact for the caller.
        logger.exception("Failed to load model")
        raise

    # Publish all three objects together, only once loading fully succeeded.
    model, preprocess, tokenizer = loaded_model, loaded_preprocess, loaded_tokenizer
    logger.info("Model loaded successfully!")
60
+
61
def download_image(url: str) -> Image.Image:
    """Download the image at *url* and return it converted to RGB.

    Args:
        url: Location of the image; fetched with a 10 second timeout.

    Returns:
        The decoded image in ``RGB`` mode.

    Raises:
        HTTPException: status 400 when the request fails, the server answers
            with an error status, or the payload cannot be decoded as an
            image. The original exception is chained as ``__cause__``.
    """
    # NOTE(security): ``url`` is caller-supplied and fetched server-side
    # without any host allow-list, so it can point at internal addresses
    # (SSRF). Restrict destinations before exposing this service publicly.
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        # convert() returns a new, fully loaded image, so the source image
        # can be closed as soon as the conversion is done.
        with Image.open(BytesIO(response.content)) as source_image:
            return source_image.convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}") from e
70
+
71
def get_image_embedding(image: Image.Image):
    """Return the L2-normalised MobileCLIP embedding of *image*.

    The image is run through the global ``preprocess`` transform, given a
    batch axis, encoded with the global ``model`` under ``torch.no_grad()``
    and normalised along the last axis.

    Args:
        image: Image to embed.

    Returns:
        The embedding as a NumPy array with the batch axis removed
        (presumably shape ``(embedding_dim,)`` - confirm against the model).

    Raises:
        HTTPException: status 500 when preprocessing or encoding fails; the
            original exception is chained as ``__cause__``.
    """
    try:
        image_tensor = preprocess(image).unsqueeze(0)
        with torch.no_grad():
            image_features = model.encode_image(image_tensor)
            # normalize() clamps the norm with a small epsilon, so an
            # all-zero feature vector stays zero instead of turning into NaN
            # as a plain division by the norm would.
            image_features = torch.nn.functional.normalize(image_features, dim=-1)
        # Remove only the batch axis added by unsqueeze(0) above.
        return image_features.squeeze(0).cpu().numpy()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}") from e
82
+
83
def get_text_embedding(text: str):
    """Return the L2-normalised MobileCLIP embedding of *text*.

    The text is tokenised as a one-element batch with the global
    ``tokenizer``, encoded with the global ``model`` under
    ``torch.no_grad()`` and normalised along the last axis.

    Args:
        text: Text to embed.

    Returns:
        The embedding as a NumPy array with the batch axis removed
        (presumably shape ``(embedding_dim,)`` - confirm against the model).

    Raises:
        HTTPException: status 500 when tokenising or encoding fails; the
            original exception is chained as ``__cause__``.
    """
    try:
        text_tokens = tokenizer([text])
        with torch.no_grad():
            text_features = model.encode_text(text_tokens)
            # normalize() clamps the norm with a small epsilon, so an
            # all-zero feature vector stays zero instead of turning into NaN
            # as a plain division by the norm would.
            text_features = torch.nn.functional.normalize(text_features, dim=-1)
        # Remove only the batch axis introduced by the one-element batch.
        return text_features.squeeze(0).cpu().numpy()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process text: {str(e)}") from e
94
+
95
def calculate_similarity(embedding1: np.ndarray, embedding2: np.ndarray) -> float:
    """Return the dot product of two embeddings as a Python float.

    The dot product equals the cosine similarity only when both inputs
    already have unit length; no normalisation is performed here.
    """
    dot_product = np.dot(embedding1, embedding2)
    return float(dot_product)
98
+
99
@app.on_event("startup")
async def startup_event():
    """Startup hook: populate the model globals by calling load_model()."""
    load_model()
103
+
104
@app.get("/")
async def root():
    """Report that the service is up (health check)."""
    health_payload = {
        "message": "MobileCLIP API is running!",
        "status": "healthy",
    }
    return health_payload
108
+
109
@app.post("/image-embedding", response_model=EmbeddingResponse)
async def image_embedding(request: ImageRequest):
    """Return the embedding of the image found at ``request.image_url``.

    Raises:
        HTTPException: 503 while the model is not loaded; any HTTPException
            raised by the download or embedding helper is passed through
            unchanged; every other failure is logged with its traceback and
            reported as a generic 500.
    """
    # Imported locally so this handler does not depend on a file-level
    # import being added; the helper ships with FastAPI itself.
    from fastapi.concurrency import run_in_threadpool

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The download (up to the 10 s timeout) and the model forward pass
        # are blocking calls. Running them on the worker thread pool keeps
        # the event loop free to serve other requests in the meantime.
        image = await run_in_threadpool(download_image, request.image_url)
        embedding = await run_in_threadpool(get_image_embedding, image)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in image_embedding")
        raise HTTPException(status_code=500, detail="Internal server error") from e
124
+
125
@app.post("/text-embedding", response_model=EmbeddingResponse)
async def text_embedding(request: TextRequest):
    """Return the embedding of ``request.text``.

    Raises:
        HTTPException: 503 while the model is not loaded; any HTTPException
            raised by the embedding helper is passed through unchanged (as
            the image and similarity endpoints do); every other failure is
            logged with its traceback and reported as a generic 500.
    """
    # Imported locally so this handler does not depend on a file-level
    # import being added; the helper ships with FastAPI itself.
    from fastapi.concurrency import run_in_threadpool

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The model forward pass is a blocking call; run it on the worker
        # thread pool so the event loop stays responsive.
        embedding = await run_in_threadpool(get_text_embedding, request.text)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        # Without this clause the broad handler below would replace the
        # helper's HTTPException with the generic one.
        raise
    except Exception as e:
        logger.exception("Error in text_embedding")
        raise HTTPException(status_code=500, detail="Internal server error") from e
137
+
138
@app.post("/similarity", response_model=SimilarityResponse)
async def similarity(request: SimilarityRequest):
    """Return the similarity between the image at ``request.image_url`` and ``request.text``.

    Both inputs are embedded with the helper functions and compared with
    calculate_similarity().

    Raises:
        HTTPException: 503 while the model is not loaded; any HTTPException
            raised by the download or embedding helpers is passed through
            unchanged; every other failure is logged with its traceback and
            reported as a generic 500.
    """
    # Imported locally so this handler does not depend on a file-level
    # import being added; the helper ships with FastAPI itself.
    from fastapi.concurrency import run_in_threadpool

    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Download and both forward passes are blocking calls; run them on
        # the worker thread pool so the event loop stays responsive.
        image = await run_in_threadpool(download_image, request.image_url)
        # Local names deliberately differ from the image_embedding /
        # text_embedding endpoint functions to avoid shadowing them.
        image_vector = await run_in_threadpool(get_image_embedding, image)
        text_vector = await run_in_threadpool(get_text_embedding, request.text)

        similarity_score = calculate_similarity(image_vector, text_vector)
        return SimilarityResponse(similarity=similarity_score)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in similarity")
        raise HTTPException(status_code=500, detail="Internal server error") from e
156
+
157
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when this file is run
    # directly as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)