# clinical-embedding / app/server_clinical_embedding.py
# Author: santanche
# Commit 2d5dc60 — update (rest): restrict allowed origins
from typing import List
from fastapi import FastAPI, Query, UploadFile, File, HTTPException
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import uvicorn
import io
import csv
import os
from clinical_embedding import ClinicalBERT
# Pydantic models for request/response
# JSON body accepted by POST /embeddings/batch.
class EmbeddingRequest(BaseModel):
    # Input texts to embed.
    sentences: List[str]
    # Pooling strategy name; checked against mean/cls/max by the endpoint,
    # not by this model.
    pooling: str = "cls"
# JSON payload returned by the embedding endpoints.
class EmbeddingResponse(BaseModel):
    # Embedding vectors, produced from the model output via ``.tolist()``.
    embeddings: List[List[float]]
    # Dimensions of the embedding array returned by the model.
    shape: List[int]
    # Pooling strategy that produced ``embeddings``.
    pooling: str
# Application instance; the metadata appears in the generated API docs.
app = FastAPI(
    version="1.0.0",
    title="Clinical BERT Embeddings API",
    description="API for generating embeddings using Bio_ClinicalBERT model",
)

# Browsers may call this API cross-origin only from the hosted Space page.
_ALLOWED_ORIGINS = ["https://santanche-clinical-embedding.hf.space"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the local "static" directory (resolved against the working
# directory) under the /app/static URL prefix.
app.mount("/app/static", StaticFiles(directory="static"), name="static")

# Shared model instance; stays None until the startup hook assigns it.
clinical_bert = None
@app.on_event("startup")
async def startup_event():
    """Create the module-wide ClinicalBERT instance once, when the server starts."""
    global clinical_bert
    # device=-1 keeps inference off the GPU; switch to device=0 for GPU.
    model = ClinicalBERT(device=-1)
    clinical_bert = model
@app.get("/")
async def root():
    """Send requests for the bare root URL to the browser UI."""
    target = "/browser/"
    return RedirectResponse(target)
# Path of the browser UI entry page, relative to the working directory.
_BROWSER_INDEX = os.path.join("static", "browser", "index.html")


@app.get("/browser/")
def get_browser():
    """Serve the browser UI's ``index.html`` file.

    The path is built once at import time. The original version also printed
    it to stdout on every request, which was a leftover debug statement.
    """
    return FileResponse(_BROWSER_INDEX)
@app.get("/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(
    sentences: List[str] = Query(..., description="List of sentences to embed"),
    pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max"),
):
    """
    Generate embeddings for a list of sentences passed as query parameters.

    Args:
        sentences: List of input sentences (repeat ``?sentences=...`` per item).
        pooling: Pooling strategy ('mean', 'cls', or 'max').

    Returns:
        EmbeddingResponse with embeddings and metadata.

    Raises:
        HTTPException: 400 if ``pooling`` is not a supported strategy.
    """
    # Raise a 400 like the POST endpoints do. Returning an ad-hoc
    # {"error": ...} dict here fails validation against ``response_model``
    # and reaches the client as a 500 Internal Server Error.
    if pooling not in ['mean', 'cls', 'max']:
        raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")

    embeddings = clinical_bert.get_embeddings(sentences, pooling=pooling)

    # Convert the array to nested lists for JSON serialization.
    return EmbeddingResponse(
        embeddings=embeddings.tolist(),
        shape=list(embeddings.shape),
        pooling=pooling,
    )
@app.get("/health")
async def health_check():
    """Report liveness and whether the startup hook has populated the model."""
    loaded = clinical_bert is not None
    return {"status": "healthy", "model_loaded": loaded}
@app.post("/embeddings/batch", response_model=EmbeddingResponse)
async def post_embeddings_batch(request: EmbeddingRequest):
    """
    POST endpoint for batch embedding generation.

    Args:
        request: EmbeddingRequest with sentences and pooling method.

    Returns:
        EmbeddingResponse with embeddings and metadata. It is declared as the
        route's ``response_model`` so the OpenAPI schema documents it, as the
        GET endpoint does.

    Raises:
        HTTPException: 400 for an unsupported pooling method or an empty
            sentence list.
    """
    if request.pooling not in ['mean', 'cls', 'max']:
        raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")
    # Reject an empty batch up front, as the CSV endpoint does for a file
    # with no terms, instead of handing an empty list to the model.
    if not request.sentences:
        raise HTTPException(status_code=400, detail="No sentences provided")

    embeddings = clinical_bert.get_embeddings(request.sentences, pooling=request.pooling)

    # Convert the array to nested lists for JSON serialization.
    return EmbeddingResponse(
        embeddings=embeddings.tolist(),
        shape=list(embeddings.shape),
        pooling=request.pooling,
    )
def _read_csv_terms(contents: bytes):
    """Parse uploaded CSV bytes and return ``(column_name, terms)`` from the first column.

    Raises:
        UnicodeDecodeError: if ``contents`` is not valid UTF-8.
        HTTPException: 400 if the CSV has no header row or no non-blank terms.
    """
    # 'utf-8-sig' decodes plain UTF-8 identically, but also drops a leading
    # byte-order mark. Without that, the BOM ends up inside the first column name.
    reader = csv.DictReader(io.StringIO(contents.decode('utf-8-sig')))
    fieldnames = reader.fieldnames
    if not fieldnames:
        raise HTTPException(status_code=400, detail="CSV must have at least one column")
    column_name = fieldnames[0]
    # Blank terms are skipped, but kept terms are passed through unstripped.
    terms = [row[column_name] for row in reader if row[column_name].strip()]
    if not terms:
        raise HTTPException(status_code=400, detail="No terms found in CSV")
    return column_name, terms


def _embeddings_to_csv(column_name, terms, embeddings) -> bytes:
    """Serialize one row per term (the term followed by its vector) as UTF-8 CSV bytes."""
    # Assumes ``embeddings`` is 2-D (one row per term); shape[1] is the vector width.
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([column_name] + [f"dim_{i}" for i in range(embeddings.shape[1])])
    writer.writerows([term] + embedding.tolist() for term, embedding in zip(terms, embeddings))
    return output.getvalue().encode('utf-8')


def _content_disposition(filename: str) -> str:
    """Build an ``attachment`` Content-Disposition value that is safe for any client filename."""
    from urllib.parse import quote

    # Starlette encodes header values as Latin-1, so a raw non-Latin-1 name
    # would fail, and quotes or spaces would produce a malformed header.
    # Send a quoted ASCII fallback plus the exact name in ``filename*`` (RFC 6266).
    fallback = "".join(
        ch if ch.isascii() and ch.isprintable() and ch not in '"\\' else "_"
        for ch in filename
    )
    encoded = quote(filename, safe="")
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{encoded}"


@app.post("/embeddings/file")
async def upload_file_embeddings(
    file: UploadFile = File(...),
    pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max"),
):
    """
    Upload a CSV file with terms and get embeddings back as CSV.

    Args:
        file: CSV file whose first column holds the terms (first row is the header).
        pooling: Pooling strategy ('mean', 'cls', or 'max').

    Returns:
        CSV file with the term column followed by ``dim_0..dim_{n-1}`` columns.

    Raises:
        HTTPException: 400 for a non-CSV upload, bad pooling value, non-UTF-8
            content, or a CSV without terms; 500 for any other processing error.
    """
    # ``filename`` may be absent on an upload. Compare case-insensitively so
    # "terms.CSV" is accepted too.
    if not file.filename or not file.filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="File must be a CSV")
    if pooling not in ['mean', 'cls', 'max']:
        raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")

    try:
        contents = await file.read()
        column_name, terms = _read_csv_terms(contents)
        embeddings = clinical_bert.get_embeddings(terms, pooling=pooling)
        payload = _embeddings_to_csv(column_name, terms, embeddings)
    except HTTPException:
        # Let the deliberate 400s from parsing through unchanged. Otherwise the
        # generic handler below would turn them into 500s.
        raise
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="File must be UTF-8 encoded")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e

    return StreamingResponse(
        io.BytesIO(payload),
        media_type="text/csv",
        headers={"Content-Disposition": _content_disposition(f"embeddings_{file.filename}")},
    )
if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string. The string
    # only resolves if this module is importable as ``main``, which breaks
    # ``python <this file>``. An import string is only required for
    # reload/workers, and reload is off here.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        reload=False,
    )