# test_model/app.py
# Uploaded via huggingface_hub
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
from model import model_instance
import time
import logging
# Logging setup: module-level logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance; the metadata below is shown in the
# Swagger UI served at /docs.
app = FastAPI(
title="Sentence Embedding API",
# Hinglish: "API to extract text embeddings from your trained model".
description="Aapke trained model se text embedding nikaalne ka API",
version="1.0.0"
)
# Request body structure for /embed.
class TextInput(BaseModel):
    """Single text to embed (request body of the /embed endpoint)."""

    # 1–512 chars: rejects empty input and overly long sequences.
    # NOTE(review): `example=` is the Pydantic v1-style keyword; v2 prefers
    # json_schema_extra / examples — confirm the installed version before changing.
    text: str = Field(..., min_length=1, max_length=512, example="Mera naam Bahadur hai")
class EmbeddingResponse(BaseModel):
    """Response payload of the /embed endpoint."""

    embedding: List[float]  # the embedding vector produced by the model
    input_text: str  # echo of the text that was embedded
    inference_time_ms: float  # model inference latency in milliseconds
# Health check endpoint
@app.get("/")
def root():
    """Liveness endpoint: confirms the API is up and points to the docs."""
    payload = {"message": "API is running! Go to /docs for Swagger UI"}
    return payload
@app.get("/health")
def health_check():
    """Health probe: report service status and whether the model is available.

    Returns ``model_loaded`` based on the actual imported model object
    instead of a hardcoded ``True``, so the probe reflects reality.
    """
    return {"status": "healthy", "model_loaded": model_instance is not None}
# Main prediction endpoint
@app.post("/embed", response_model=EmbeddingResponse)
async def get_embedding(input_data: TextInput):
    """Compute the embedding for a single text.

    Returns the embedding vector, the original text, and the model
    inference time in milliseconds. Any model failure surfaces as an
    HTTP 500 whose detail carries the error message.
    """
    try:
        # Lazy %-style args: the string is only formatted if INFO is enabled.
        logger.info("Processing text: %s...", input_data.text[:50])
        # perf_counter is monotonic — latency can't go negative or jump
        # if the wall clock is adjusted mid-request (unlike time.time()).
        start_time = time.perf_counter()
        embedding = model_instance.get_embedding(input_data.text)
        inference_time = (time.perf_counter() - start_time) * 1000  # milliseconds
        return EmbeddingResponse(
            embedding=embedding,
            input_text=input_data.text,
            inference_time_ms=round(inference_time, 2),
        )
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors raised by lower layers.
        raise
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Embedding request failed")
        raise HTTPException(status_code=500, detail=str(e))
# Batch processing (optional)
class BatchTextInput(BaseModel):
    """Request body of /embed/batch: a list of texts to embed."""

    # NOTE(review): individual texts are not length-validated here, unlike
    # TextInput.text (1–512 chars) — confirm whether that is intentional.
    texts: List[str]
@app.post("/embed/batch")
async def get_batch_embeddings(input_data: BatchTextInput):
    """Compute embeddings for a batch of texts.

    Mirrors /embed's error handling: a model failure on any text is
    logged with its traceback and reported as HTTP 500, instead of
    escaping as an unhandled exception. An empty list yields
    ``{"results": [], "count": 0}``.
    """
    try:
        results = [
            {"text": text, "embedding": model_instance.get_embedding(text)}
            for text in input_data.texts
        ]
    except Exception as e:
        logger.exception("Batch embedding request failed")
        raise HTTPException(status_code=500, detail=str(e))
    return {"results": results, "count": len(results)}