# NOTE: The lines below are Hugging Face Spaces page metadata (Space status,
# file size, commit hashes, and a line-number gutter) accidentally captured
# when this file was scraped; they are commented out so the module parses.
# Spaces: Sleeping / Sleeping
# File size: 6,327 Bytes
# Commits: 606d030 0c30b5c 48b6569 (repeated gutter entries omitted)
# Line-number gutter: 1..184
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import open_clip
from mobileclip.modules.common.mobileone import reparameterize_model
from PIL import Image
import requests
from io import BytesIO
import logging
# Verify NumPy is importable up front: a broken NumPy install would otherwise
# surface as an opaque failure deeper inside torch/open_clip. Uses print()
# because logging is not configured yet at this point in the module.
try:
    import numpy as np
    print("✅ NumPy imported successfully:", np.__version__)
except ImportError as e:
    print("❌ NumPy failed to import:", str(e))

import os
# Set cache directories
# Redirect model/weight caches into /app/.cache so downloads land on a
# writable path (e.g. inside a container where $HOME may be read-only).
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TORCH_HOME'] = '/app/.cache/torch'
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases in favor of HF_HOME — confirm the installed version still honors it.
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache/transformers'

# Create cache directories if they don't exist
os.makedirs('/app/.cache', exist_ok=True)
os.makedirs('/app/.cache/torch', exist_ok=True)
os.makedirs('/app/.cache/transformers', exist_ok=True)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="MobileCLIP API",
    description="API for MobileCLIP image and text embeddings",
    version="1.0.0"
)

# Global variables for model
# Populated once by load_model() during startup; each endpoint treats
# `model is None` as "not ready" and answers 503 until loading completes.
model = None       # MobileCLIP-S2 model, eval mode, reparameterized
preprocess = None  # image -> tensor transform returned by open_clip
tokenizer = None   # text tokenizer for MobileCLIP-S2
class TextRequest(BaseModel):
    """Request body for /text-embedding."""
    text: str  # text to embed
class ImageRequest(BaseModel):
    """Request body for /image-embedding."""
    image_url: str  # publicly reachable URL of the image to embed
class SimilarityRequest(BaseModel):
    """Request body for /similarity (image/text pair to compare)."""
    image_url: str  # publicly reachable URL of the image
    text: str       # text to compare against the image
class EmbeddingResponse(BaseModel):
    """Response carrying a normalized embedding as a flat list of floats."""
    embedding: list
class SimilarityResponse(BaseModel):
    """Response carrying a single cosine-similarity score."""
    similarity: float
def load_model():
    """Download, initialize, and reparameterize the MobileCLIP-S2 model.

    Populates the module-level ``model``, ``preprocess`` and ``tokenizer``
    globals used by the endpoints. Any exception from the underlying loaders
    is logged and re-raised so startup fails loudly.
    """
    global model, preprocess, tokenizer
    try:
        logger.info("📥 Downloading MobileCLIP-S2 model...")
        # Explicitly set cache directory so weights land on the writable
        # path prepared at import time.
        model, _, preprocess = open_clip.create_model_and_transforms(
            'MobileCLIP-S2',
            pretrained='datacompdr',
            cache_dir='/app/.cache'
        )
        logger.info("🔧 Loading tokenizer...")
        tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

        # Reparameterize for inference — the mobileclip helper fuses the
        # MobileOne-style branches for faster inference; run after eval().
        logger.info("⚡ Reparameterizing model for inference...")
        model.eval()
        model = reparameterize_model(model)
        logger.info("✅ Model loaded and optimized successfully!")
    except Exception as e:
        logger.error(f"❌ Failed to load model: {str(e)}")
        # Bare `raise` preserves the original traceback; `raise e` would
        # re-anchor it at this frame.
        raise
def download_image(url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Any failure (network error, non-2xx status, undecodable image data)
    is surfaced to the client as a 400.
    """
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}")
def get_image_embedding(image: Image.Image):
    """Encode a PIL image into a unit-length MobileCLIP embedding.

    Returns a 1-D CPU torch tensor; any processing failure becomes a 500.
    """
    try:
        batch = preprocess(image).unsqueeze(0)  # add batch dimension
        with torch.no_grad():
            feats = model.encode_image(batch)
        # L2-normalize so dot products equal cosine similarity.
        feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process image: {str(e)}")
def get_text_embedding(text: str):
    """Tokenize and encode text into a unit-length MobileCLIP embedding.

    Returns a 1-D CPU torch tensor; any processing failure becomes a 500.
    """
    try:
        tokens = tokenizer([text])
        with torch.no_grad():
            feats = model.encode_text(tokens)
        # L2-normalize so dot products equal cosine similarity.
        feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats.squeeze().cpu()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to process text: {str(e)}")
def calculate_similarity(embedding1, embedding2) -> float:
    """Cosine similarity between two L2-normalized embeddings.

    The original annotations claimed ``np.ndarray``, but the callers pass
    CPU torch tensors; ``np.asarray`` accepts both (and lists), so the
    contract now matches actual use. Inputs are assumed unit-normalized,
    making the dot product equal to cosine similarity.
    """
    return float(np.dot(np.asarray(embedding1), np.asarray(embedding2)))
# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers — consider migrating when the FastAPI version allows.
@app.on_event("startup")
async def startup_event():
    """Load model on startup.

    Blocks serving until load_model() finishes; until then the model
    globals stay None and the endpoints answer 503. If loading raises,
    the exception propagates and the app fails to start.
    """
    logger.info("🚀 Starting MobileCLIP API...")
    logger.info("📦 Loading model - this may take 2-5 minutes...")
    load_model()
    logger.info("✅ Model loaded successfully! API is ready.")
@app.get("/")
async def root():
    """Health check: report that the service process is up."""
    payload = {"message": "MobileCLIP API is running!", "status": "healthy"}
    return payload
@app.post("/image-embedding", response_model=EmbeddingResponse)
async def image_embedding(request: ImageRequest):
    """Return the normalized MobileCLIP embedding for the image at the given URL."""
    # Guard: 503 until startup has populated the model globals.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        vector = get_image_embedding(download_image(request.image_url))
        return EmbeddingResponse(embedding=vector.tolist())
    except HTTPException:
        # Propagate the specific 400/500 raised by the helpers unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in image_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
@app.post("/text-embedding", response_model=EmbeddingResponse)
async def text_embedding(request: TextRequest):
    """Return the normalized MobileCLIP embedding for the given text."""
    # Guard: 503 until startup has populated the model globals.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        embedding = get_text_embedding(request.text)
        return EmbeddingResponse(embedding=embedding.tolist())
    except HTTPException:
        # Fix: the other endpoints re-raise HTTPExceptions; without this the
        # descriptive 500 from get_text_embedding was masked by the generic
        # handler below.
        raise
    except Exception as e:
        logger.error(f"Error in text_embedding: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
@app.post("/similarity", response_model=SimilarityResponse)
async def similarity(request: SimilarityRequest):
    """Score how well the text matches the image (cosine similarity)."""
    # Guard: 503 until startup has populated the model globals.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # Local names avoid shadowing the sibling endpoint functions
        # image_embedding / text_embedding.
        img_vec = get_image_embedding(download_image(request.image_url))
        txt_vec = get_text_embedding(request.text)
        score = calculate_similarity(img_vec, txt_vec)
        return SimilarityResponse(similarity=score)
    except HTTPException:
        # Propagate the specific 400/500 raised by the helpers unchanged.
        raise
    except Exception as e:
        logger.error(f"Error in similarity: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
if __name__ == "__main__":
    # Local/dev entry point. Port 7860 is presumably the Hugging Face Spaces
    # convention — confirm against the deployment target.
    # (Also removes a trailing "|" scrape artifact that broke the syntax.)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)