Spaces:
Sleeping
Sleeping
| from typing import List | |
| from fastapi import FastAPI, Query, UploadFile, File, HTTPException | |
| from fastapi.responses import RedirectResponse | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import io | |
| import csv | |
| import os | |
| from clinical_embedding import ClinicalBERT | |
# Request/response schemas (Pydantic)
class EmbeddingRequest(BaseModel):
    """Request body for embedding generation."""

    # Sentences to embed.
    sentences: List[str]
    # Pooling strategy name; the handlers in this file accept
    # 'mean', 'cls' or 'max'.
    pooling: str = "cls"
class EmbeddingResponse(BaseModel):
    """Response payload produced by the embedding handlers."""

    # Model output converted with `.tolist()` by the handlers.
    embeddings: List[List[float]]
    # `.shape` of the model output, as a plain list.
    shape: List[int]
    # Pooling strategy that was applied.
    pooling: str
# FastAPI application object (title/description/version are its metadata).
app = FastAPI(
    title="Clinical BERT Embeddings API",
    description="API for generating embeddings using Bio_ClinicalBERT model",
    version="1.0.0",
)
# CORS: allow the web page at the origin below to call this API,
# with credentials, using any method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://santanche-clinical-embedding.hf.space"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Expose the local "static" directory under the /app/static URL prefix.
app.mount("/app/static", StaticFiles(directory="static"), name="static")

# Module-wide model handle; remains None until startup_event() assigns it.
clinical_bert = None
# NOTE(review): no startup-hook registration (decorator) is visible on this
# function in this view — confirm it is actually wired to the app's startup.
async def startup_event():
    """Create the ClinicalBERT instance and store it in the module global."""
    global clinical_bert
    # device=-1 is used here; per the original note, pass device=0 for GPU.
    clinical_bert = ClinicalBERT(device=-1)
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered with `app`.
async def root():
    """Redirect the caller to the browser UI at /browser/."""
    browser_url = "/browser/"
    return RedirectResponse(url=browser_url)
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered with `app`.
def get_browser():
    """Serve the browser UI entry page.

    Returns:
        FileResponse for ``static/browser/index.html`` (path is relative to
        the process working directory).
    """
    # Build the path once; the former debug print of this path on every
    # request has been removed.
    index_path = os.path.join("static", "browser", "index.html")
    return FileResponse(index_path)
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered with `app`.
async def get_embeddings(
    sentences: List[str] = Query(..., description="List of sentences to embed"),
    pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max"),
):
    """
    Generate embeddings for a list of sentences.

    Args:
        sentences: List of input sentences.
        pooling: Pooling strategy ('mean', 'cls', or 'max').

    Returns:
        EmbeddingResponse with embeddings and metadata.

    Raises:
        HTTPException: 400 if ``pooling`` is not a supported strategy;
            503 if the model global has not been initialised.
    """
    # Reject unknown strategies with a real HTTP error (same as the other
    # handlers in this file) instead of an "error" dict in a normal response.
    if pooling not in ('mean', 'cls', 'max'):
        raise HTTPException(
            status_code=400,
            detail="Invalid pooling method. Choose from: mean, cls, max",
        )

    # clinical_bert starts as None; fail clearly rather than with an
    # AttributeError when it has not been assigned yet.
    if clinical_bert is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    embeddings = clinical_bert.get_embeddings(sentences, pooling=pooling)

    # .tolist() yields plain Python lists for JSON serialization.
    return EmbeddingResponse(
        embeddings=embeddings.tolist(),
        shape=list(embeddings.shape),
        pooling=pooling,
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered with `app`.
async def health_check():
    """Report a static "healthy" status plus whether the model global is set."""
    model_ready = clinical_bert is not None
    return {"status": "healthy", "model_loaded": model_ready}
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered with `app`.
async def post_embeddings_batch(request: EmbeddingRequest):
    """
    POST endpoint for batch embedding generation.

    Args:
        request: EmbeddingRequest with sentences and pooling method.

    Returns:
        EmbeddingResponse with embeddings and metadata.

    Raises:
        HTTPException: 400 if ``request.pooling`` is not a supported strategy;
            503 if the model global has not been initialised.
    """
    if request.pooling not in ('mean', 'cls', 'max'):
        raise HTTPException(
            status_code=400,
            detail="Invalid pooling method. Choose from: mean, cls, max",
        )

    # clinical_bert starts as None; fail clearly rather than with an
    # AttributeError when it has not been assigned yet.
    if clinical_bert is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    embeddings = clinical_bert.get_embeddings(
        request.sentences, pooling=request.pooling
    )

    # .tolist() yields plain Python lists for JSON serialization.
    return EmbeddingResponse(
        embeddings=embeddings.tolist(),
        shape=list(embeddings.shape),
        pooling=request.pooling,
    )
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered with `app`.
async def upload_file_embeddings(
    file: UploadFile = File(...),
    pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max"),
):
    """
    Upload a CSV file with terms and get embeddings back as CSV.

    The first column of the uploaded CSV is read as the list of terms (blank
    values are skipped). The output repeats that column and appends one
    ``dim_<i>`` column per embedding dimension.

    Args:
        file: CSV file whose first column contains the terms.
        pooling: Pooling strategy ('mean', 'cls', or 'max').

    Returns:
        StreamingResponse carrying the CSV as an attachment.

    Raises:
        HTTPException: 400 for a non-CSV filename, an unsupported pooling
            value, a CSV without columns or terms, or non-UTF-8 content;
            503 if the model global has not been initialised;
            500 for any other processing failure.
    """
    # The client may omit the filename; treat that as "not a CSV" instead of
    # crashing on None. The extension check is case-insensitive.
    filename = file.filename or ""
    if not filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail="File must be a CSV")

    if pooling not in ('mean', 'cls', 'max'):
        raise HTTPException(
            status_code=400,
            detail="Invalid pooling method. Choose from: mean, cls, max",
        )

    # clinical_bert starts as None; fail clearly when it is not assigned yet.
    if clinical_bert is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        contents = await file.read()
        # 'utf-8-sig' decodes plain UTF-8 as well and drops a leading BOM,
        # which would otherwise end up inside the first column name.
        text = contents.decode('utf-8-sig')
        csv_reader = csv.DictReader(io.StringIO(text))

        fieldnames = csv_reader.fieldnames
        if not fieldnames:
            raise HTTPException(
                status_code=400, detail="CSV must have at least one column"
            )
        column_name = fieldnames[0]

        # Keep the original (unstripped) value of every non-blank cell.
        terms = []
        for row in csv_reader:
            value = row.get(column_name)
            if value and value.strip():
                terms.append(value)
        if not terms:
            raise HTTPException(status_code=400, detail="No terms found in CSV")

        embeddings = clinical_bert.get_embeddings(terms, pooling=pooling)

        # Header row: term column followed by one column per dimension.
        output = io.StringIO()
        writer = csv.writer(output)
        writer.writerow(
            [column_name] + [f"dim_{i}" for i in range(embeddings.shape[1])]
        )
        writer.writerows(
            [term] + embedding.tolist()
            for term, embedding in zip(terms, embeddings)
        )

        # The filename is client-supplied: restrict it to a conservative
        # ASCII subset before placing it in a response header.
        safe_name = "".join(
            ch if ch.isascii() and (ch.isalnum() or ch in "._-") else "_"
            for ch in filename
        )
        return StreamingResponse(
            io.BytesIO(output.getvalue().encode('utf-8')),
            media_type="text/csv",
            headers={
                "Content-Disposition": f"attachment; filename=embeddings_{safe_name}"
            },
        )
    except HTTPException:
        # Let the deliberate 4xx errors raised above through unchanged; the
        # generic handler below would otherwise rewrap them as 500s.
        raise
    except UnicodeDecodeError:
        raise HTTPException(
            status_code=400, detail="File must be UTF-8 encoded"
        ) from None
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing file: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Script entry point: serve "main:app" with uvicorn on 0.0.0.0:8000,
    # with auto-reload turned off.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=False)