import logging
import time
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field

from model import model_instance
|
|
|
|
|
# Module logger; records propagate to the root logger, which basicConfig
# configures at INFO level (a no-op if the root already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
|
|
|
# ASGI application object; the metadata below is shown in the OpenAPI/Swagger docs.
# The description is romanized Hindi for roughly
# "API for extracting text embeddings from your trained model".
app = FastAPI(
    version="1.0.0",
    title="Sentence Embedding API",
    description="Aapke trained model se text embedding nikaalne ka API",
)
|
|
|
|
|
class TextInput(BaseModel):
    # Request body for POST /embed: a single text to embed.
    # min_length/max_length count string characters. min_length=1 only rejects
    # "", so whitespace-only input is accepted.
    # NOTE(review): confirm whitespace-only text should reach the model.
    # `examples` (a list) replaces the singular `example` extra kwarg, which
    # Pydantic v2 deprecates. The example is romanized Hindi for "My name is Bahadur".
    text: str = Field(
        ...,
        min_length=1,
        max_length=512,
        examples=["Mera naam Bahadur hai"],
    )
|
|
|
class EmbeddingResponse(BaseModel):
    # Response body for POST /embed (used as its response_model).
    # (Plain comments rather than a docstring: Pydantic would publish a
    # docstring as the schema description in OpenAPI.)

    # Vector returned by model_instance.get_embedding for input_text.
    embedding: List[float]

    # Echo of the request text the embedding was computed for.
    input_text: str

    # Elapsed time of the model call in milliseconds; the /embed handler
    # rounds it to 2 decimal places.
    inference_time_ms: float
|
|
|
|
|
# GET /: minimal landing endpoint that points callers at the Swagger UI.
@app.get("/")
def root():
    greeting = "API is running! Go to /docs for Swagger UI"
    return {"message": greeting}
|
|
|
# GET /health: reports service status and whether a model instance is present.
@app.get("/health")
def health_check():
    # Previously "model_loaded" was hard-coded to True and could never signal
    # a problem. Report whether the imported model object actually exists.
    # NOTE(review): this checks presence only, not that inference succeeds.
    return {"status": "healthy", "model_loaded": model_instance is not None}
|
|
|
|
|
def _timed_embedding(text: str):
    """Embed *text*; return ``(embedding, elapsed_ms)`` for the model call only.

    Runs inside a worker thread, so the measured time excludes any time the
    request spent waiting for a free thread.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump, which corrupts duration measurements.
    start = time.perf_counter()
    embedding = model_instance.get_embedding(text)
    return embedding, (time.perf_counter() - start) * 1000


# POST /embed: compute the embedding for a single text.
# Returns an EmbeddingResponse with the vector, the echoed input, and the
# model-call time in ms. Any failure is logged with its traceback and turned
# into HTTP 500 with str(exc) as the detail.
@app.post("/embed", response_model=EmbeddingResponse)
async def get_embedding(input_data: TextInput):
    text = input_data.text
    # Lazy %-args: the message is only formatted when INFO is enabled.
    # NOTE(review): this logs a prefix of raw user input; confirm that is acceptable.
    logger.info("Processing text: %s...", text[:50])

    try:
        # model_instance.get_embedding is a plain blocking call. Awaiting
        # nothing while it runs would stall the event loop for every other
        # request (including /health), so offload it to the threadpool.
        # NOTE(review): concurrent requests now call the model from several
        # threads at once. Confirm model_instance.get_embedding is thread-safe.
        embedding, elapsed_ms = await run_in_threadpool(_timed_embedding, text)

        return EmbeddingResponse(
            embedding=embedding,
            input_text=text,
            inference_time_ms=round(elapsed_ms, 2),
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) dropped.
        logger.exception("Error: %s", e)
        # NOTE(review): str(e) is returned to the client verbatim and may expose
        # internal details; consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
class BatchTextInput(BaseModel):
    # Request body for POST /embed/batch: the texts to embed, in order.
    # NOTE(review): unlike TextInput.text, neither the list size nor each
    # string's length is constrained here. Empty strings, over-long texts and
    # arbitrarily large batches are all accepted and sent to the model.
    # Consider bounds matching TextInput. Mind that list-size constraint syntax
    # differs between Pydantic v1 (min_items/max_items) and v2 (min_length/max_length).
    texts: List[str]
|
|
|
def _embed_texts(texts: List[str]) -> list:
    """Embed each text in order with one model call per text (runs in a worker thread)."""
    return [model_instance.get_embedding(text) for text in texts]


# POST /embed/batch: embed several texts in one request.
# Returns {"results": [{"text", "embedding"}, ...], "count": n}, with results in
# input order. Failures are logged with a traceback and turned into HTTP 500,
# matching the /embed endpoint.
@app.post("/embed/batch")
async def get_batch_embeddings(input_data: BatchTextInput):
    texts = input_data.texts
    logger.info("Processing batch of %d texts", len(texts))

    try:
        # The model calls are blocking; running N of them directly in this
        # coroutine would freeze the event loop for the whole batch.
        # NOTE(review): confirm model_instance.get_embedding is thread-safe,
        # since concurrent requests now run it in parallel worker threads.
        embeddings = await run_in_threadpool(_embed_texts, texts)
    except Exception as e:
        # Previously errors escaped unlogged; log with traceback like /embed does.
        logger.exception("Batch embedding error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    results = [
        {"text": text, "embedding": embedding}
        for text, embedding in zip(texts, embeddings)
    ]
    return {"results": results, "count": len(results)}