File size: 6,327 Bytes
606d030
 
 
 
 
 
 
 
 
0c30b5c
 
 
 
 
48b6569
 
 
 
 
 
 
 
 
 
 
606d030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48b6569
 
 
 
 
 
 
 
606d030
 
 
48b6569
606d030
 
48b6569
606d030
 
48b6569
606d030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c30b5c
606d030
 
 
 
 
 
 
 
 
 
 
0c30b5c
606d030
 
 
 
 
 
 
 
 
 
48b6569
 
606d030
48b6569
606d030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import open_clip
from mobileclip.modules.common.mobileone import reparameterize_model
from PIL import Image
import requests
from io import BytesIO
import logging
# Import-time smoke check: fail loudly (to stdout) if NumPy is broken in the
# container image. Uses print() because the logger is not configured yet.
try:
    import numpy as np
    print("✅ NumPy imported successfully:", np.__version__)
except ImportError as e:
    print("❌ NumPy failed to import:", str(e))
import os

# Point all model/download caches at a writable location inside the container.
# NOTE(review): these env vars are set *after* torch/open_clip are imported
# above — any library that reads them at import time may not pick them up;
# confirm this ordering is intentional.
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TORCH_HOME'] = '/app/.cache/torch'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'

# Ensure every cache directory exists before any download starts.
for _cache_dir in ('/app/.cache', '/app/.cache/torch', '/app/.cache/transformers'):
    os.makedirs(_cache_dir, exist_ok=True)

# Module-level logger for the whole API.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application exposing the MobileCLIP embedding endpoints.
app = FastAPI(
    title="MobileCLIP API",
    description="API for MobileCLIP image and text embeddings",
    version="1.0.0",
)

# Model state shared across requests; populated once by load_model() at startup.
model = None       # reparameterized MobileCLIP-S2 model
preprocess = None  # image preprocessing transform
tokenizer = None   # text tokenizer

class TextRequest(BaseModel):
    """Request body for /text-embedding: the text to embed."""
    text: str

class ImageRequest(BaseModel):
    """Request body for /image-embedding: URL of the image to embed."""
    image_url: str

class SimilarityRequest(BaseModel):
    """Request body for /similarity: an image URL and a text to compare."""
    image_url: str
    text: str

class EmbeddingResponse(BaseModel):
    """Response: L2-normalized embedding as a flat list of floats."""
    embedding: list

class SimilarityResponse(BaseModel):
    """Response: cosine similarity score between image and text."""
    similarity: float

def load_model():
    """Load MobileCLIP-S2, its preprocessing transform, and its tokenizer.

    Populates the module-level ``model``, ``preprocess`` and ``tokenizer``
    globals, then reparameterizes the model (folds MobileOne branches) for
    faster inference.

    Raises:
        Exception: re-raises any failure from download/initialization after
            logging it, so the startup hook aborts loudly.
    """
    global model, preprocess, tokenizer

    try:
        logger.info("📥 Downloading MobileCLIP-S2 model...")
        # Explicit cache_dir keeps downloads inside the container's writable cache.
        model, _, preprocess = open_clip.create_model_and_transforms(
            'MobileCLIP-S2', 
            pretrained='datacompdr',
            cache_dir='/app/.cache'
        )
        logger.info("🔧 Loading tokenizer...")
        tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

        # Reparameterize for inference; must be in eval mode first.
        logger.info("⚡ Reparameterizing model for inference...")
        model.eval()
        model = reparameterize_model(model)
        logger.info("✅ Model loaded and optimized successfully!")

    except Exception as e:
        logger.error(f"❌ Failed to load model: {str(e)}")
        # Bug fix: bare `raise` re-raises the active exception with its
        # original traceback intact (`raise e` was redundant and unidiomatic).
        raise

def download_image(url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Any failure (network error, non-2xx status, undecodable image data)
    is translated into an HTTP 400 for the caller.
    """
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}")

def get_image_embedding(image: Image.Image):
    """Encode a PIL image into an L2-normalized MobileCLIP embedding.

    Returns a 1-D CPU torch tensor of unit length. Any processing failure
    becomes an HTTP 500 with a descriptive detail.
    """
    try:
        batch = preprocess(image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            feats = model.encode_image(batch)
            feats = feats / feats.norm(dim=-1, keepdim=True)  # unit length
        return feats.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}")

def get_text_embedding(text: str):
    """Encode text into an L2-normalized MobileCLIP embedding.

    Returns a 1-D CPU torch tensor of unit length. Any processing failure
    becomes an HTTP 500 with a descriptive detail.
    """
    try:
        tokens = tokenizer([text])
        with torch.no_grad():
            feats = model.encode_text(tokens)
            feats = feats / feats.norm(dim=-1, keepdim=True)  # unit length
        return feats.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process text: {str(e)}")

def calculate_similarity(embedding1, embedding2) -> float:
    """Cosine similarity between two embeddings.

    This is a plain dot product, which equals cosine similarity ONLY
    because both producers (get_image_embedding / get_text_embedding)
    return L2-normalized vectors — documented here since the original
    docstring silently relied on that invariant.

    Accepts numpy arrays or torch CPU tensors (np.dot converts tensors
    via the __array__ protocol); the original np.ndarray annotations were
    wrong — callers actually pass torch tensors.
    """
    return float(np.dot(embedding1, embedding2))

@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    # Blocks the event loop until the model is fully downloaded and
    # reparameterized; the server will not answer requests until this returns.
    logger.info("🚀 Starting MobileCLIP API...")
    logger.info("📦 Loading model - this may take 2-5 minutes...")
    load_model()
    logger.info("✅ Model loaded successfully! API is ready.")

@app.get("/")
async def root():
    """Liveness probe: reports that the service process is up."""
    status_payload = {"message": "MobileCLIP API is running!", "status": "healthy"}
    return status_payload

@app.post("/image-embedding", response_model=EmbeddingResponse)
async def image_embedding(request: ImageRequest):
    """Return the L2-normalized embedding for the image at `request.image_url`."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        img = download_image(request.image_url)
        vec = get_image_embedding(img)
        return EmbeddingResponse(embedding=vec.tolist())
    except HTTPException:
        raise  # preserve specific 400/500 details from the helpers
    except Exception as e:
        logger.error(f"Error in image_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")

@app.post("/text-embedding", response_model=EmbeddingResponse)
async def text_embedding(request: TextRequest):
    """Return the L2-normalized embedding for `request.text` as a list of floats."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        embedding = get_text_embedding(request.text)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        # Bug fix: get_text_embedding raises HTTPException with a descriptive
        # detail; without this clause the generic handler below swallowed it
        # and replaced it with a vague message. Mirrors image_embedding
        # and similarity for consistency.
        raise
    except Exception as e:
        logger.error(f"Error in text_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")

@app.post("/similarity", response_model=SimilarityResponse)
async def similarity(request: SimilarityRequest):
    """Return the cosine similarity between the image at `image_url` and `text`."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Local names chosen to avoid shadowing the image_embedding /
        # text_embedding endpoint functions defined above.
        img_vec = get_image_embedding(download_image(request.image_url))
        txt_vec = get_text_embedding(request.text)
        score = calculate_similarity(img_vec, txt_vec)
        return SimilarityResponse(similarity=score)
    except HTTPException:
        raise  # preserve specific 400/500 details from the helpers
    except Exception as e:
        logger.error(f"Error in similarity: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")

if __name__ == "__main__":
    # Direct-execution entry point: serve on all interfaces.
    # Port 7860 is presumably chosen for Hugging Face Spaces — confirm.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)