"""HTTP API that serves text and image embeddings from a single model.

The module builds a FastAPI application, loads the model through
``ModelHandler`` at import time, and exposes:

* ``GET /``            -- liveness message.
* ``POST /embeddings`` -- embeddings for a list of text and/or image inputs.
"""
import base64
import json
import logging
import os
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import numpy as np
import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field

# Import handler
from handler import ModelHandler

logger = logging.getLogger(__name__)

# Upper bound, in seconds, on how long a remote image download may take.
# Without a timeout a slow or unresponsive host would stall the request.
IMAGE_FETCH_TIMEOUT_SECONDS = 10

app = FastAPI(title="Embedding Model API")

# Add CORS middleware
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model handler
model_handler = ModelHandler()
model_handler.initialize(None)  # We'll handle context manually


# Define request/response models
class TextInput(BaseModel):
    """A single text input."""

    text: str = Field(..., description="The text to generate embeddings for")


class ImageInput(BaseModel):
    """A single image input, given as an http(s) URL or a base64 data URL."""

    image: str = Field(
        ..., description="URL or base64-encoded image to generate embeddings for"
    )


class EmbeddingRequest(BaseModel):
    """Request body for ``POST /embeddings``."""

    inputs: List[Union[TextInput, ImageInput]] = Field(
        ..., description="List of text or image inputs"
    )
    task: str = Field(
        "retrieval", description="Task type: retrieval, text-matching, or code"
    )


class EmbeddingResponse(BaseModel):
    """Response body for ``POST /embeddings``: one vector per input, in order."""

    embeddings: List[List[float]] = Field(..., description="List of embeddings")


@app.get("/")
async def root():
    """Return a static message indicating that the service is up."""
    return {"message": "Embedding Model API is running"}


def _load_image(image_data: str) -> Image.Image:
    """Turn an image reference from the request into an RGB PIL image.

    Args:
        image_data: Either an ``http://`` / ``https://`` URL, or a data URL
            starting with ``data:image`` whose payload after the first comma
            is base64-encoded.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: With status 400 when the reference has an unsupported
            form, the download fails, the base64 payload is malformed, or the
            bytes cannot be decoded as an image.
    """
    if image_data.startswith(("http://", "https://")):
        # NOTE(security): the server fetches a caller-supplied URL. If this
        # service can reach internal hosts, restrict the allowed destinations.
        try:
            response = requests.get(image_data, timeout=IMAGE_FETCH_TIMEOUT_SECONDS)
            # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
            response.raise_for_status()
        except requests.RequestException as exc:
            raise HTTPException(
                status_code=400, detail=f"Could not fetch image: {exc}"
            ) from exc
        raw_bytes = response.content
    elif image_data.startswith("data:image"):
        _, separator, image_b64 = image_data.partition(",")
        if not separator:
            # A data URL without a comma has no payload section.
            raise HTTPException(status_code=400, detail="Invalid image format")
        try:
            raw_bytes = base64.b64decode(image_b64)
        except ValueError as exc:  # binascii.Error is a ValueError subclass
            raise HTTPException(
                status_code=400, detail="Invalid base64 image data"
            ) from exc
    else:
        raise HTTPException(status_code=400, detail="Invalid image format")

    try:
        return Image.open(BytesIO(raw_bytes)).convert("RGB")
    except (OSError, ValueError) as exc:
        # PIL signals unreadable or truncated image data with OSError subclasses.
        raise HTTPException(
            status_code=400, detail="Could not decode image"
        ) from exc


@app.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Compute one embedding per input item.

    Text inputs are passed to the model as strings and image inputs as RGB
    PIL images, in the order they appear in the request.

    Args:
        request: Validated request body with the inputs and the task name.

    Returns:
        A dict with key ``"embeddings"`` holding a list of float vectors.

    Raises:
        HTTPException: 400 for an invalid or unreadable image input; 500 when
            the model returns no ``"sentence_embedding"`` entry or any other
            unexpected error occurs.
    """
    if not request.inputs:
        # Nothing to embed; skip the model call entirely.
        return {"embeddings": []}

    try:
        # Process inputs
        inputs = []
        for item in request.inputs:
            if isinstance(item, TextInput):
                inputs.append(item.text)
            elif isinstance(item, ImageInput):
                inputs.append(_load_image(item.image))

        # Get embeddings. Gradients are never needed when serving, and a
        # tensor that requires grad cannot be converted with .numpy().
        with torch.no_grad():
            features = model_handler.model.tokenize(inputs)
            outputs = model_handler.model.forward(features, task=request.task)

        embeddings = outputs.get("sentence_embedding", None)
        if embeddings is None:
            raise HTTPException(status_code=500, detail="Failed to generate embeddings")

        embeddings = embeddings.detach().cpu()
        if embeddings.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast before converting.
            embeddings = embeddings.float()

        # Convert to list for JSON serialization
        embeddings_list = embeddings.numpy().tolist()
        return {"embeddings": embeddings_list}
    except HTTPException:
        # Keep deliberate status codes (e.g. 400) instead of masking them as 500.
        raise
    except Exception as exc:
        logger.exception("Embedding request failed")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


if __name__ == "__main__":
    # Run the API server
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)