Spaces:
Running
Running
import logging
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Load variables from a .env file (python-dotenv) before the environment is
# read below.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title='CLIP API',
    description='Returns CLIP embedding for text and image',
)

# Read from the lower-case variable name `hf_token`; None when it is unset.
HF_TOKEN = os.getenv('hf_token')

# The processor and model are loaded once, at import time, and shared by the
# functions below. A load failure is logged and re-raised, so the module
# fails to import rather than serving without a model.
# NOTE(review): newer transformers releases prefer `token=` over
# `use_auth_token=` -- confirm against the pinned transformers version.
logger.info("Loading CLIP processor and model")
try:
    processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-large-patch14",
        use_auth_token=HF_TOKEN,
    )
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        use_auth_token=HF_TOKEN,
    )
    # Put the model in evaluation mode; it is only used for inference here.
    clip_model.eval()
    logger.info("CLIP model loaded successfully")
except Exception as load_error:
    logger.error(f"Failed to load CLIP model : {load_error}")
    raise
def get_text_embedding(text: str):
    """Compute the CLIP text embedding for a single string.

    The text is wrapped in a one-element batch, tokenized by the module-level
    ``processor`` (with padding and truncation enabled) and passed to
    ``clip_model.get_text_features`` with gradient tracking disabled.

    Args:
        text: The string to embed.

    Returns:
        The text features with the leading batch axis squeezed out, converted
        with ``tolist()`` (presumably a flat list of floats -- the caller
        takes ``len()`` of it as the embedding dimension).

    Raises:
        Exception: Any error from tokenization or the model is logged and
            re-raised unchanged.
    """
    try:
        encoded = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            features = clip_model.get_text_features(**encoded)
        logger.info("Text embedding generated")
        # Drop the batch axis introduced by wrapping `text` in a list.
        vector = features.squeeze(0)
        return vector.tolist()
    except Exception as err:
        logger.error(f"Error while generating embedding : {err}")
        raise
async def root():
    """Log the access and return the API's welcome message.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text -- confirm it is registered on ``app``.

    Returns:
        A dict with a single ``"message"`` key.
    """
    logger.info("Root endpoint accessed")
    greeting = {"message": "Welcome to the CLIP embedding API."}
    return greeting
async def get_embedding_text(text: str):
    """Return the CLIP embedding of ``text`` together with its length.

    Delegates to ``get_text_embedding``, which is a plain synchronous call,
    so this coroutine does not yield while the embedding is computed.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text -- confirm it is registered on ``app``.

    Args:
        text: The string to embed.

    Returns:
        ``{"embedding": <embedding>, "dimension": len(<embedding>)}``.

    Raises:
        Exception: Whatever ``get_text_embedding`` raises is propagated.
    """
    # Plain literal: the message has no placeholders, so the f-prefix the
    # original carried was redundant. The input text is not logged.
    logger.info("Embedding endpoint called with text")
    embedding = get_text_embedding(text)
    return {"embedding": embedding, "dimension": len(embedding)}
async def process_image(file: UploadFile = File(...)):
    """Compute the CLIP image embedding for an uploaded image.

    The upload is decoded with Pillow, converted to RGB, preprocessed by the
    module-level ``processor`` and passed to
    ``clip_model.get_image_features`` with gradient tracking disabled.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text -- confirm it is registered on ``app``.

    Args:
        file: The uploaded image (required multipart form file).

    Returns:
        ``{"embedding": ...}`` where the value is ``tolist()`` of the image
        features. The batch axis is not squeezed here, so this is presumably
        a nested list with one row -- TODO confirm with API consumers.

    Raises:
        HTTPException: Status 400 when Pillow cannot decode the upload.
    """
    logger.info("Processing image")
    try:
        # `convert` forces the pixel data to be read, so truncated or corrupt
        # files fail here as well as files Pillow cannot identify.
        image = Image.open(file.file).convert("RGB")
    except OSError as exc:
        # PIL.UnidentifiedImageError is an OSError subclass. Report a bad
        # upload as a client error instead of an unhandled 500.
        logger.warning("Rejected upload that is not a decodable image: %s", exc)
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embeddings = clip_model.get_image_features(**inputs)
    return {"embedding": embeddings.tolist()}