santanche committed on
Commit
29ba00a
·
1 Parent(s): 958c684

feat (start): initial setup

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ .venv
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.9

# Run as a non-root user (the uid-1000 pattern required by Hugging Face Spaces).
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Copy and install requirements first so dependency layers are cached
# independently of application-code changes.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user ./app /app
# Fix: this commit ships app/server_clinical_embedding.py — the previous CMD
# referenced "server_sentiment_analysis", a module that does not exist here,
# so the container would crash on startup.
CMD ["uvicorn", "server_clinical_embedding:app", "--host", "0.0.0.0", "--port", "7860"]
app/clinical_embedding.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from transformers import pipeline
3
+ from typing import List
4
+
5
class ClinicalBERT:
    """
    A wrapper class for Bio_ClinicalBERT model to generate sentence embeddings.
    """

    # Pooling strategies accepted by get_embeddings().
    _POOLINGS = ('mean', 'cls', 'max')

    def __init__(self, model_name: str = "emilyalsentzer/Bio_ClinicalBERT", device: int = -1):
        """
        Initialize the ClinicalBERT model using pipeline.

        Args:
            model_name: The Hugging Face model identifier
            device: Device to run the model on (-1 for CPU, 0 for first GPU, etc.)
        """
        self.model_name = model_name

        # Create feature extraction pipeline
        print(f"Loading {model_name}...")
        self.pipe = pipeline(
            "feature-extraction",
            model=model_name,
            device=device
        )
        print(f"Model loaded successfully on device {device}")

    def get_embeddings(self, sentences: List[str], pooling: str = 'mean') -> np.ndarray:
        """
        Generate embeddings for a list of sentences.

        Args:
            sentences: List of input sentences
            pooling: Pooling strategy ('mean', 'cls', or 'max')

        Returns:
            numpy array of shape (num_sentences, embedding_dim);
            an empty array when `sentences` is empty.

        Raises:
            ValueError: If `pooling` is not one of 'mean', 'cls', 'max'.
        """
        # Fix: the original silently fell back to mean pooling for any
        # unrecognized strategy, contradicting the documented contract.
        # Fail loudly instead so typos are caught at the call site.
        if pooling not in self._POOLINGS:
            raise ValueError(
                f"Unknown pooling strategy {pooling!r}; expected one of {self._POOLINGS}"
            )

        if not sentences:
            return np.array([])

        # The pipeline returns, per sentence, a nested list of shape
        # (1, num_tokens, embedding_dim).
        outputs = self.pipe(sentences)

        # Reduce each (num_tokens, embedding_dim) matrix to one vector.
        embeddings = []
        for sentence_output in outputs:
            # (1, num_tokens, embedding_dim) -> (num_tokens, embedding_dim)
            tokens_array = np.array(sentence_output).squeeze(0)

            if pooling == 'cls':
                # Use the [CLS] token (first token) as the sentence vector.
                embedding = tokens_array[0]
            elif pooling == 'max':
                # Element-wise max across tokens.
                embedding = np.max(tokens_array, axis=0)
            else:
                # Mean pooling (default): average across all tokens.
                embedding = np.mean(tokens_array, axis=0)

            embeddings.append(embedding)

        # Stack into a 2D array: (num_sentences, embedding_dim)
        return np.vstack(embeddings)
app/server_clinical_embedding.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from typing import List

import uvicorn
from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel

from clinical_embedding import ClinicalBERT
7
+
8
# Pydantic models for request/response
class EmbeddingRequest(BaseModel):
    """Request body for embedding generation.

    NOTE(review): not referenced by the visible endpoints (the GET
    /embeddings route takes query parameters) — presumably intended for a
    future POST route; confirm before removing.
    """
    # Input sentences to embed.
    sentences: List[str]
    # Pooling strategy; mirrors the default used by ClinicalBERT.get_embeddings.
    pooling: str = 'mean'
12
+
13
+
14
class EmbeddingResponse(BaseModel):
    """Response payload for the /embeddings endpoint."""
    # One embedding vector per input sentence.
    embeddings: List[List[float]]
    # [num_sentences, embedding_dim] — the numpy shape of the stacked result.
    shape: List[int]
    # Pooling strategy that was applied ('mean', 'cls', or 'max').
    pooling: str
18
+
19
+
20
# Initialize FastAPI app
app = FastAPI(
    title="Clinical BERT Embeddings API",
    description="API for generating embeddings using Bio_ClinicalBERT model",
    version="1.0.0"
)

# Initialize model (global instance)
# None until the startup event handler loads ClinicalBERT; request handlers
# share this single instance.
clinical_bert = None
29
+
30
+
31
@app.on_event("startup")
async def startup_event():
    """Load model on startup.

    Loads ClinicalBERT once into the module-level `clinical_bert` global so
    all request handlers share a single model instance.
    """
    global clinical_bert
    clinical_bert = ClinicalBERT(device=-1)  # Use device=0 for GPU
36
+
37
+
38
@app.get("/")
async def root():
    """Describe the API and list its available endpoints."""
    endpoints = {
        "/embeddings": "GET - Generate embeddings from sentences",
        "/docs": "GET - Interactive API documentation",
    }
    return {
        "message": "Clinical BERT Embeddings API",
        "endpoints": endpoints,
    }
48
+
49
+
50
@app.get("/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(
    sentences: List[str] = Query(..., description="List of sentences to embed"),
    pooling: str = Query('mean', description="Pooling strategy: mean, cls, or max")
):
    """
    Generate embeddings for a list of sentences.

    Args:
        sentences: List of input sentences
        pooling: Pooling strategy ('mean', 'cls', or 'max')

    Returns:
        EmbeddingResponse with embeddings and metadata

    Raises:
        HTTPException: 400 for an unknown pooling strategy; 503 when the
            model has not finished loading yet.
    """
    # Fix: the original returned a plain {"error": ...} dict here, which
    # violates response_model=EmbeddingResponse and surfaces to the client
    # as a 500 response-validation error instead of a 400.
    if pooling not in ('mean', 'cls', 'max'):
        raise HTTPException(
            status_code=400,
            detail="Invalid pooling method. Choose from: mean, cls, max",
        )

    # The model is loaded by the startup hook; reject requests that arrive
    # before it is ready instead of crashing with AttributeError (500).
    if clinical_bert is None:
        raise HTTPException(status_code=503, detail="Model is still loading")

    # Generate embeddings
    embeddings = clinical_bert.get_embeddings(sentences, pooling=pooling)

    return EmbeddingResponse(
        embeddings=embeddings.tolist(),  # numpy -> JSON-serializable lists
        shape=list(embeddings.shape),
        pooling=pooling,
    )
82
+
83
+
84
@app.get("/health")
async def health_check():
    """Report service liveness and whether the model has been loaded."""
    model_ready = clinical_bert is not None
    return {"status": "healthy", "model_loaded": model_ready}
91
+
92
+
93
if __name__ == "__main__":
    # Run the server for local development.
    # Fix: this module is server_clinical_embedding.py, so the uvicorn import
    # string must reference it — the original "main:app" pointed at a module
    # that does not exist and made direct execution fail.
    uvicorn.run(
        "server_clinical_embedding:app",
        host="0.0.0.0",
        port=8000,
        reload=False
    )
app/test_clinical_embedding.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from clinical_embedding import ClinicalBERT

# Smoke-test the embedding wrapper on CPU (pass device=0 to use the first GPU).
model = ClinicalBERT(device=-1)

# Two phrases that are clinical synonyms.
sentences = [
    "Heart Attack",
    "Myocardial Infarction"
]

# Default strategy: mean pooling over token vectors.
embeddings = model.get_embeddings(sentences, pooling='mean')
print(f"Embeddings shape: {embeddings.shape}")
print(f"First embedding (truncated): {embeddings[0][:5]}...")

# Compare the alternative pooling strategies.
embeddings_cls = model.get_embeddings(sentences, pooling='cls')
print(f"\nCLS pooling shape: {embeddings_cls.shape}")

embeddings_max = model.get_embeddings(sentences, pooling='max')
print(f"Max pooling shape: {embeddings_max.shape}")
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Web Framework
2
+ fastapi==0.104.1
3
+ uvicorn[standard]==0.24.0
4
+ pydantic==2.5.0
5
+
6
+ # Machine Learning
7
+ transformers==4.35.2
8
+ torch==2.1.1
9
+ numpy==1.24.3
10
+
11
+ # Optional: for GPU support, also install:
12
+ # torch==2.1.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html