Spaces:
Sleeping
Sleeping
feat (start): initial setup
Browse files- .gitignore +2 -0
- Dockerfile +13 -0
- app/clinical_embedding.py +67 -0
- app/server_clinical_embedding.py +100 -0
- app/test_clinical_embedding.py +22 -0
- requirements.txt +12 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
.venv
|
Dockerfile
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.9

# Run as a non-root user (required by Hugging Face Spaces).
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy and install requirements first so the dependency layer is cached
# independently of application-code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user ./app /app

# FIX: the entrypoint previously referenced "server_sentiment_analysis:app",
# a module that does not exist in this image (leftover from a sentiment
# template). The server module copied into /app above is
# server_clinical_embedding.py, which defines `app`.
CMD ["uvicorn", "server_clinical_embedding:app", "--host", "0.0.0.0", "--port", "7860"]
app/clinical_embedding.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from transformers import pipeline
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
class ClinicalBERT:
    """
    Thin wrapper around the Bio_ClinicalBERT checkpoint that turns raw
    sentences into fixed-size sentence vectors via token pooling.
    """

    def __init__(self, model_name: str = "emilyalsentzer/Bio_ClinicalBERT", device: int = -1):
        """
        Build the underlying Hugging Face feature-extraction pipeline.

        Args:
            model_name: Hugging Face model identifier to load.
            device: -1 runs on CPU; 0, 1, ... select a GPU index.
        """
        self.model_name = model_name

        # Feature-extraction pipelines return per-token hidden states,
        # which get_embeddings() pools into one vector per sentence.
        print(f"Loading {model_name}...")
        self.pipe = pipeline(
            "feature-extraction",
            model=model_name,
            device=device
        )
        print(f"Model loaded successfully on device {device}")

    def get_embeddings(self, sentences: List[str], pooling: str = 'mean') -> np.ndarray:
        """
        Embed each sentence and pool its token vectors into one vector.

        Args:
            sentences: Input sentences to embed.
            pooling: 'cls' keeps the first ([CLS]) token, 'max' takes the
                element-wise maximum over tokens; any other value falls
                back to mean pooling (the default).

        Returns:
            Array of shape (num_sentences, embedding_dim); an empty array
            when no sentences are supplied.
        """
        if not sentences:
            return np.array([])

        # Single batched pipeline call; each element is shaped
        # (1, num_tokens, embedding_dim) per input sentence.
        raw_outputs = self.pipe(sentences)

        pooled = []
        for raw in raw_outputs:
            # Drop the leading batch axis -> (num_tokens, embedding_dim).
            token_vectors = np.asarray(raw).squeeze(0)

            if pooling == 'cls':
                pooled.append(token_vectors[0])            # [CLS] token only
            elif pooling == 'max':
                pooled.append(token_vectors.max(axis=0))   # per-dim maximum
            else:
                pooled.append(token_vectors.mean(axis=0))  # default: mean

        # (num_sentences, embedding_dim)
        return np.vstack(pooled)
|
app/server_clinical_embedding.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List

from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import uvicorn

from clinical_embedding import ClinicalBERT
|
| 7 |
+
|
| 8 |
+
# Pydantic models for request/response
class EmbeddingRequest(BaseModel):
    """Request schema: sentences to embed plus the pooling strategy."""
    sentences: List[str]  # raw input sentences
    pooling: str = 'mean'  # one of 'mean', 'cls', 'max'
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class EmbeddingResponse(BaseModel):
    """Response schema for the /embeddings endpoint."""
    embeddings: List[List[float]]  # one embedding vector per input sentence
    shape: List[int]  # [num_sentences, embedding_dim]
    pooling: str  # pooling strategy that was applied
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Initialize FastAPI app
app = FastAPI(
    title="Clinical BERT Embeddings API",
    description="API for generating embeddings using Bio_ClinicalBERT model",
    version="1.0.0"
)

# Initialize model (global instance)
# Stays None until the startup handler finishes loading the model.
clinical_bert = None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@app.on_event("startup")
|
| 32 |
+
async def startup_event():
|
| 33 |
+
"""Load model on startup"""
|
| 34 |
+
global clinical_bert
|
| 35 |
+
clinical_bert = ClinicalBERT(device=-1) # Use device=0 for GPU
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@app.get("/")
|
| 39 |
+
async def root():
|
| 40 |
+
"""Root endpoint with API information"""
|
| 41 |
+
return {
|
| 42 |
+
"message": "Clinical BERT Embeddings API",
|
| 43 |
+
"endpoints": {
|
| 44 |
+
"/embeddings": "GET - Generate embeddings from sentences",
|
| 45 |
+
"/docs": "GET - Interactive API documentation"
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@app.get("/embeddings", response_model=EmbeddingResponse)
|
| 51 |
+
async def get_embeddings(
|
| 52 |
+
sentences: List[str] = Query(..., description="List of sentences to embed"),
|
| 53 |
+
pooling: str = Query('mean', description="Pooling strategy: mean, cls, or max")
|
| 54 |
+
):
|
| 55 |
+
"""
|
| 56 |
+
Generate embeddings for a list of sentences.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
sentences: List of input sentences
|
| 60 |
+
pooling: Pooling strategy ('mean', 'cls', or 'max')
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
EmbeddingResponse with embeddings and metadata
|
| 64 |
+
"""
|
| 65 |
+
# Validate pooling method
|
| 66 |
+
if pooling not in ['mean', 'cls', 'max']:
|
| 67 |
+
return {
|
| 68 |
+
"error": "Invalid pooling method. Choose from: mean, cls, max"
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
# Generate embeddings
|
| 72 |
+
embeddings = clinical_bert.get_embeddings(sentences, pooling=pooling)
|
| 73 |
+
|
| 74 |
+
# Convert to list for JSON serialization
|
| 75 |
+
embeddings_list = embeddings.tolist()
|
| 76 |
+
|
| 77 |
+
return EmbeddingResponse(
|
| 78 |
+
embeddings=embeddings_list,
|
| 79 |
+
shape=list(embeddings.shape),
|
| 80 |
+
pooling=pooling
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@app.get("/health")
|
| 85 |
+
async def health_check():
|
| 86 |
+
"""Health check endpoint"""
|
| 87 |
+
return {
|
| 88 |
+
"status": "healthy",
|
| 89 |
+
"model_loaded": clinical_bert is not None
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
# Run the server
|
| 95 |
+
uvicorn.run(
|
| 96 |
+
"main:app",
|
| 97 |
+
host="0.0.0.0",
|
| 98 |
+
port=8000,
|
| 99 |
+
reload=False
|
| 100 |
+
)
|
app/test_clinical_embedding.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Smoke test: embed two synonymous clinical phrases with each pooling mode."""
from clinical_embedding import ClinicalBERT

# Initialize the model (use device=0 for GPU, device=-1 for CPU)
clinical_bert = ClinicalBERT(device=-1)

# Example sentences
# "Heart Attack" and "Myocardial Infarction" are clinical synonyms, so a
# good model should place their embeddings close together.
sentences = [
    "Heart Attack",
    "Myocardial Infarction"
]

# Get embeddings with mean pooling
# Expected shape: (2, embedding_dim) — one vector per sentence.
embeddings = clinical_bert.get_embeddings(sentences, pooling='mean')
print(f"Embeddings shape: {embeddings.shape}")
print(f"First embedding (truncated): {embeddings[0][:5]}...")

# Try different pooling strategies
embeddings_cls = clinical_bert.get_embeddings(sentences, pooling='cls')
print(f"\nCLS pooling shape: {embeddings_cls.shape}")

embeddings_max = clinical_bert.get_embeddings(sentences, pooling='max')
print(f"Max pooling shape: {embeddings_max.shape}")
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web Framework
|
| 2 |
+
fastapi==0.104.1
|
| 3 |
+
uvicorn[standard]==0.24.0
|
| 4 |
+
pydantic==2.5.0
|
| 5 |
+
|
| 6 |
+
# Machine Learning
|
| 7 |
+
transformers==4.35.2
|
| 8 |
+
torch==2.1.1
|
| 9 |
+
numpy==1.24.3
|
| 10 |
+
|
| 11 |
+
# Optional: for GPU support, also install:
|
| 12 |
+
# torch==2.1.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
|