"""FastAPI service that exposes CLIP text and image embeddings.

The CLIP processor and model are loaded once, at import time, and shared by
all request handlers.
"""

import logging
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Populate the process environment from a local .env file (if any) before the
# token is read below.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title='CLIP API',
    description='Returns CLIP embedding for text and image',
)

HF_TOKEN = os.getenv('hf_token')
CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"

logger.info("Loading CLIP processor and model")
try:
    # NOTE(review): recent transformers releases document `token=` as the
    # replacement for the deprecated `use_auth_token=`; kept as-is because the
    # installed version is not visible from this file -- confirm and migrate.
    processor = CLIPProcessor.from_pretrained(
        CLIP_MODEL_NAME,
        use_auth_token=HF_TOKEN,
    )
    clip_model = CLIPModel.from_pretrained(
        CLIP_MODEL_NAME,
        use_auth_token=HF_TOKEN,
    )
    # Inference only: switch the model to evaluation mode.
    clip_model.eval()
    logger.info("CLIP model loaded successfully")
except Exception as e:
    # Top-level boundary: log the failure, then let the import fail loudly so
    # the service never starts without a model.
    logger.error("Failed to load CLIP model : %s", e)
    raise


def get_text_embedding(text: str):
    """Return the CLIP text embedding of *text* as a flat list of floats.

    The text is processed as a batch of one; the batch dimension is squeezed
    away before the tensor is converted to a Python list.

    Any exception raised while tokenising or running the model is logged and
    re-raised unchanged.
    """
    try:
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Gradients are not needed for inference.
        with torch.no_grad():
            text_embeddings = clip_model.get_text_features(**inputs)
        logger.info("Text embedding generated")
        return text_embeddings.squeeze(0).tolist()
    except Exception as e:
        logger.error("Error while generating embedding : %s", e)
        raise


@app.get("/")
async def root():
    """Return a static welcome message."""
    logger.info("Root endpoint accessed")
    return {"message": "Welcome to the CLIP embedding API."}


@app.get("/embedding")
async def get_embedding_text(text: str):
    """Return the CLIP embedding of the `text` query parameter and its length."""
    # NOTE(review): get_text_embedding is synchronous and is called directly
    # from this coroutine, so inference runs on the event loop -- consider
    # offloading it if concurrent throughput matters.
    logger.info("Embedding endpoint called with text")
    embedding = get_text_embedding(text)
    return {"embedding": embedding, "dimension": len(embedding)}


@app.post("/clip/process")
async def process_image(file: UploadFile = File(...)):
    """Return the CLIP image embedding of the uploaded file.

    The upload is decoded with Pillow and converted to RGB before being passed
    to the CLIP processor. Unlike the text endpoint, the result of
    ``get_image_features`` is returned via ``tolist()`` without squeezing, so
    the response keeps whatever leading dimension the model output has.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    logger.info("Processing image")
    try:
        # convert() forces the pixel data to be decoded, so truncated or
        # corrupt files fail here rather than inside the model call.
        image = Image.open(file.file).convert("RGB")
    except OSError as exc:
        # Pillow's UnidentifiedImageError is an OSError subclass; report bad
        # uploads as a client error instead of an opaque 500.
        logger.warning("Rejected upload that is not a valid image: %s", exc)
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    inputs = processor(images=image, return_tensors="pt")
    # Gradients are not needed for inference.
    with torch.no_grad():
        embeddings = clip_model.get_image_features(**inputs)
    return {"embedding": embeddings.tolist()}