"""Sentence Embedding API.

Exposes the project-trained model (``model.model_instance``) over HTTP:
single-text and batch endpoints returning embedding vectors plus timing
metadata. Run with e.g. ``uvicorn main:app``.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
from model import model_instance
import time
import logging

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Contract shared by both endpoints: non-empty text, capped at 512 chars
# (presumably the model's max sequence length — TODO confirm).
MAX_TEXT_LENGTH = 512

app = FastAPI(
    title="Sentence Embedding API",
    description="Aapke trained model se text embedding nikaalne ka API",
    version="1.0.0",
)


# Request body structure
class TextInput(BaseModel):
    text: str = Field(..., min_length=1, max_length=MAX_TEXT_LENGTH,
                      example="Mera naam Bahadur hai")


class EmbeddingResponse(BaseModel):
    embedding: List[float]
    input_text: str
    inference_time_ms: float


# Health check endpoints
@app.get("/")
def root():
    return {"message": "API is running! Go to /docs for Swagger UI"}


@app.get("/health")
def health_check():
    # model_instance is imported at module load, so reaching this handler
    # implies the model module imported successfully.
    return {"status": "healthy", "model_loaded": True}


# Main prediction endpoint
# NOTE(review): get_embedding looks like a blocking (CPU-bound) call inside an
# async handler, which would stall the event loop — confirm, and consider a
# plain `def` handler so FastAPI runs it in its threadpool.
@app.post("/embed", response_model=EmbeddingResponse)
async def get_embedding(input_data: TextInput):
    """Embed a single text and report inference latency in milliseconds."""
    try:
        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info("Processing text: %s...", input_data.text[:50])

        # perf_counter is monotonic — immune to wall-clock adjustments,
        # unlike time.time(), so durations are always non-negative.
        start_time = time.perf_counter()
        embedding = model_instance.get_embedding(input_data.text)
        inference_time = (time.perf_counter() - start_time) * 1000  # milliseconds

        return EmbeddingResponse(
            embedding=embedding,
            input_text=input_data.text,
            inference_time_ms=round(inference_time, 2),
        )
    except HTTPException:
        raise
    except Exception as e:
        # logger.exception also records the traceback for debugging.
        logger.exception("Error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


# Batch processing (optional)
class BatchTextInput(BaseModel):
    texts: List[str]


@app.post("/embed/batch")
async def get_batch_embeddings(input_data: BatchTextInput):
    """Embed a list of texts; returns per-text embeddings and the count.

    Enforces the same per-text constraints as /embed (non-empty,
    <= MAX_TEXT_LENGTH chars) and rejects an empty batch, instead of
    failing deep inside the model with an opaque 500.
    """
    if not input_data.texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")
    for i, text in enumerate(input_data.texts):
        if not text or len(text) > MAX_TEXT_LENGTH:
            raise HTTPException(
                status_code=400,
                detail=f"texts[{i}] must be 1-{MAX_TEXT_LENGTH} characters",
            )

    try:
        results = [
            {"text": text, "embedding": model_instance.get_embedding(text)}
            for text in input_data.texts
        ]
    except Exception as e:
        # Mirror the error handling of /embed so clients get a clean JSON
        # error instead of a raw server traceback.
        logger.exception("Batch error: %s", e)
        raise HTTPException(status_code=500, detail=str(e))

    return {"results": results, "count": len(results)}