santanche committed on
Commit
29dff86
·
1 Parent(s): 2d5dc60

feat (comparison): three models comparison

Browse files
app/clinical_embedding.py CHANGED
@@ -1,67 +1,282 @@
1
  import numpy as np
2
- from transformers import pipeline
3
- from typing import List
 
 
 
 
4
 
5
- class ClinicalBERT:
 
 
 
 
 
 
 
6
  """
7
- A wrapper class for Bio_ClinicalBERT model to generate sentence embeddings.
8
  """
9
-
10
- def __init__(self, model_name: str = "emilyalsentzer/Bio_ClinicalBERT", device: int = -1):
 
 
 
 
 
 
 
 
11
  """
12
- Initialize the ClinicalBERT model using pipeline.
13
-
14
- Args:
15
- model_name: The Hugging Face model identifier
16
- device: Device to run the model on (-1 for CPU, 0 for first GPU, etc.)
17
  """
18
- self.model_name = model_name
 
19
 
20
- # Create feature extraction pipeline
21
- print(f"Loading {model_name}...")
22
- self.pipe = pipeline(
23
- "feature-extraction",
24
- model=model_name,
25
- device=device
26
- )
27
- print(f"Model loaded successfully on device {device}")
28
-
29
- def get_embeddings(self, sentences: List[str], pooling: str = 'cls') -> np.ndarray:
30
- """
31
- Generate embeddings for a list of sentences.
32
 
33
- Args:
34
- sentences: List of input sentences
35
- pooling: Pooling strategy ('mean', 'cls', or 'max')
36
 
37
- Returns:
38
- numpy array of shape (num_sentences, embedding_dim)
39
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if not sentences:
41
  return np.array([])
 
 
42
 
43
- # Get embeddings from pipeline
44
- # The pipeline returns a list with shape (1, num_tokens, embedding_dim) per sentence
45
- outputs = self.pipe(sentences)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- # Apply pooling strategy to each sentence
48
- embeddings = []
49
- for sentence_output in outputs:
50
- # Convert to numpy array and squeeze the first dimension
51
- # Shape: (1, num_tokens, embedding_dim) -> (num_tokens, embedding_dim)
52
- tokens_array = np.array(sentence_output).squeeze(0)
53
-
54
- if pooling == 'cls':
55
- # Use [CLS] token (first token)
56
- embedding = tokens_array[0]
57
- elif pooling == 'max':
58
- # Max pooling across tokens (dim 0)
59
- embedding = np.max(tokens_array, axis=0)
60
- else: # mean pooling (default)
61
- # Average across all tokens (dim 0)
62
- embedding = np.mean(tokens_array, axis=0)
63
-
64
- embeddings.append(embedding)
 
 
 
 
65
 
66
- # Stack embeddings into a 2D array: (num_sentences, embedding_dim)
67
- return np.vstack(embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
+ import re
5
+ from typing import List, Tuple, Union, Optional
6
+ import gensim.downloader as api
7
+ from abc import ABC, abstractmethod
8
 
9
class BaseEmbedder(ABC):
    """Common interface implemented by every embedding backend."""

    @abstractmethod
    def get_embeddings(self, sentences: List[str], pooling: str = 'cls') -> np.ndarray:
        """Return one embedding vector per input sentence as a 2-D array."""
        pass
15
+
16
class BertEmbedder(BaseEmbedder):
    """
    Wrapper for BERT-based models (ClinicalBERT, BERT, etc.)

    Sentences may mark a span of interest with square brackets, e.g.
    "patient shows [elevated troponin] levels"; pooling is then applied
    only to the tokens inside the brackets, while the full sentence still
    provides context to the encoder.
    """

    def __init__(self, model_name: str, device: int = -1):
        """
        Load the tokenizer and model.

        Args:
            model_name: The Hugging Face model identifier.
            device: -1 for CPU, otherwise the CUDA device index.
        """
        self.output_hidden_states = True
        # Fix: accept any non-negative GPU index, not just index 0 — the
        # previous check (`device == 0`) silently fell back to CPU for
        # device=1, 2, ... despite the documented convention.
        self.device = f"cuda:{device}" if device >= 0 and torch.cuda.is_available() else "cpu"
        print(f"Loading {model_name} on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()
        print(f"Model {model_name} loaded successfully.")

    def _extract_bracketed_content(self, text: str) -> Tuple[str, List[Tuple[int, int]]]:
        """
        Strip [brackets] from *text* and record where their contents land.

        Returns:
            (clean_text, spans) where spans is a list of (start, end)
            character ranges in clean_text covering each bracketed target.
            If no brackets are present, a single span covers the whole text.
        """
        clean_text = ""
        target_spans = []  # (start_char, end_char) ranges in clean_text
        cursor = 0
        i = 0
        while i < len(text):
            if text[i] == '[':
                end_bracket = text.find(']', i)
                if end_bracket != -1:
                    # Copy text preceding the bracket, then the bare content.
                    clean_text += text[cursor:i]
                    content = text[i + 1:end_bracket]
                    start_span = len(clean_text)
                    clean_text += content
                    target_spans.append((start_span, len(clean_text)))
                    cursor = end_bracket + 1
                    i = end_bracket + 1
                    continue
            i += 1

        # Append any text after the last bracket (or all of it if none).
        clean_text += text[cursor:]

        if not target_spans:
            # No brackets: treat the entire sentence as the target.
            return text, [(0, len(text))]
        return clean_text, target_spans

    def get_embeddings(self, sentences: List[str], pooling: str = 'cls') -> np.ndarray:
        """
        Embed each sentence, honoring bracketed target spans.

        Args:
            sentences: Input sentences, optionally with [bracketed] targets.
            pooling: 'cls', 'mean' or 'max'. When a sentence has a bracketed
                span, 'cls' falls back to mean pooling over the span's
                tokens (a sub-span has no CLS token of its own).

        Returns:
            Array of shape (num_sentences, hidden_size); empty array for
            empty input.
        """
        if not sentences:
            return np.array([])

        embeddings_list = []
        for sent in sentences:
            clean_text, target_spans = self._extract_bracketed_content(sent)

            # Offset mapping lets us align character spans to token indices.
            inputs = self.tokenizer(
                clean_text,
                return_tensors="pt",
                truncation=True,
                padding=True,  # Padding not strictly needed for size 1 but good practice
                return_offsets_mapping=True
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask
                )

            # shape: (seq_len, hidden_dim) after dropping the batch dim
            last_hidden_state = outputs.last_hidden_state[0]
            offset_mapping = inputs.offset_mapping[0].cpu().numpy()

            # Collect indices of tokens overlapping any target span.
            target_token_indices = []
            for token_idx, (start_offset, end_offset) in enumerate(offset_mapping):
                if start_offset == 0 and end_offset == 0:
                    continue  # special tokens ([CLS]/[SEP]) map to (0, 0)
                # Half-open interval overlap test between token and span.
                is_in_target = any(
                    end_offset > span_start and start_offset < span_end
                    for span_start, span_end in target_spans
                )
                if is_in_target:
                    target_token_indices.append(token_idx)

            has_brackets = (clean_text != sent)

            if not target_token_indices:
                # No specific target tokens matched: use the whole sentence.
                if pooling == 'cls':
                    selected_tokens = last_hidden_state[0:1]  # the [CLS] token
                elif len(last_hidden_state) > 2:
                    selected_tokens = last_hidden_state[1:-1]  # drop CLS/SEP
                else:
                    selected_tokens = last_hidden_state  # fallback
            else:
                selected_tokens = last_hidden_state[target_token_indices]

            if len(selected_tokens) == 0:
                # Defensive fallback: zero vector if nothing was selected.
                embedding = np.zeros(self.model.config.hidden_size)
            elif pooling == 'max':
                embedding = torch.max(selected_tokens, dim=0)[0].cpu().numpy()
            elif pooling == 'cls' and not has_brackets:
                embedding = last_hidden_state[0].cpu().numpy()
            else:
                # 'mean', or 'cls' requested on a bracketed sub-span.
                embedding = torch.mean(selected_tokens, dim=0).cpu().numpy()

            embeddings_list.append(embedding)

        # (num_sentences, embedding_dim)
        return np.vstack(embeddings_list)
195
+
196
+
197
class Word2VecEmbedder(BaseEmbedder):
    """
    Wrapper for Word2Vec-style vectors (via Gensim's downloader).

    If a sentence contains [brackets], only the bracketed words are
    embedded; otherwise all words are used. Word vectors are pooled with
    max or mean ('cls' falls back to mean — Word2Vec has no CLS token).
    """

    def __init__(self, model_name: str = "glove-wiki-gigaword-50"):
        """
        Load a pretrained vector model by gensim-downloader name.

        Args:
            model_name: The gensim downloader identifier.
        """
        print(f"Loading Word2Vec model {model_name}...")
        try:
            self.model = api.load(model_name)
            print(f"Word2Vec model {model_name} loaded.")
        except Exception as e:
            # Best effort: keep the service alive; get_embeddings will
            # return an empty array while the model is unavailable.
            print(f"Failed to load gensim model: {e}")
            self.model = None

    def _extract_words_and_brackets(self, text: str) -> List[str]:
        """
        Tokenize *text* into lowercase words.
        If brackets are present, ONLY returns words inside brackets.
        If no brackets, returns all words.
        """
        targets = re.findall(r'\[(.*?)\]', text)

        words = []
        if targets:
            # Process only the content inside brackets, joined into one stream.
            full_target_text = " ".join(targets)
            words = re.findall(r'\b\w+\b', full_target_text.lower())
        else:
            words = re.findall(r'\b\w+\b', text.lower())

        return words

    def get_embeddings(self, sentences: List[str], pooling: str = 'cls') -> np.ndarray:
        """
        Embed each sentence by pooling its word vectors.

        Args:
            sentences: Input sentences, optionally with [bracketed] targets.
            pooling: 'max' for max pooling; anything else means mean.

        Returns:
            Array of shape (num_sentences, vector_size); empty array if the
            model failed to load or the input is empty.
        """
        if self.model is None:
            return np.array([])
        # Fix: guard empty input — np.vstack([]) would raise ValueError.
        if not sentences:
            return np.array([])

        embeddings_list = []
        vector_size = self.model.vector_size

        for sent in sentences:
            words = self._extract_words_and_brackets(sent)

            # Keep only in-vocabulary words.
            valid_vectors = [self.model[w] for w in words if w in self.model]

            if not valid_vectors:
                # No known word in the sentence: fall back to a zero vector.
                embeddings_list.append(np.zeros(vector_size))
                continue

            vectors_np = np.vstack(valid_vectors)

            if pooling == 'max':
                emb = np.max(vectors_np, axis=0)
            else:
                # Mean for 'mean' and 'cls' (w2v has no CLS)
                emb = np.mean(vectors_np, axis=0)

            embeddings_list.append(emb)

        return np.vstack(embeddings_list)
266
+
267
# Factory/Container
class ModelManager:
    """Lazily constructs and caches one embedder per model type."""

    def __init__(self):
        # model_type -> embedder instance, populated on first request
        self.models = {}

    def get_model(self, model_type: str):
        """Return the cached embedder for *model_type*, building it on demand."""
        cached = self.models.get(model_type)
        if cached is not None:
            return cached

        # Dispatch table of supported model types -> factory callables.
        factories = {
            'clinical_bert': lambda: BertEmbedder("emilyalsentzer/Bio_ClinicalBERT"),
            'bert': lambda: BertEmbedder("bert-base-uncased"),
            'word2vec': lambda: Word2VecEmbedder(),
        }
        factory = factories.get(model_type)
        if factory is None:
            raise ValueError(f"Unknown model type: {model_type}")

        self.models[model_type] = factory()
        return self.models[model_type]
app/server_clinical_embedding.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import List
2
  from fastapi import FastAPI, Query, UploadFile, File, HTTPException
3
  from fastapi.responses import RedirectResponse
4
  from fastapi.responses import StreamingResponse
@@ -11,25 +11,25 @@ import io
11
  import csv
12
  import os
13
 
14
- from clinical_embedding import ClinicalBERT
15
 
16
  # Pydantic models for request/response
17
  class EmbeddingRequest(BaseModel):
18
  sentences: List[str]
19
  pooling: str = 'cls'
20
-
21
 
22
  class EmbeddingResponse(BaseModel):
23
  embeddings: List[List[float]]
24
  shape: List[int]
25
  pooling: str
26
-
27
 
28
  # Initialize FastAPI app
29
  app = FastAPI(
30
- title="Clinical BERT Embeddings API",
31
- description="API for generating embeddings using Bio_ClinicalBERT model",
32
- version="1.0.0"
33
  )
34
 
35
  # Add CORS middleware to allow web page access
@@ -44,16 +44,17 @@ app.add_middleware(
44
  # Serve static files
45
  app.mount("/app/static", StaticFiles(directory="static"), name="static")
46
 
47
- # Initialize model (global instance)
48
- clinical_bert = None
49
-
50
 
51
  @app.on_event("startup")
52
  async def startup_event():
53
- """Load model on startup"""
54
- global clinical_bert
55
- clinical_bert = ClinicalBERT(device=-1) # Use device=0 for GPU
56
-
 
 
57
 
58
  @app.get("/")
59
  async def root():
@@ -61,33 +62,29 @@ async def root():
61
 
62
  @app.get("/browser/")
63
  def get_browser():
64
- print(os.path.join("static", "browser", "index.html"))
65
  return FileResponse(os.path.join("static", "browser", "index.html"))
66
 
67
-
68
  @app.get("/embeddings", response_model=EmbeddingResponse)
69
  async def get_embeddings(
70
  sentences: List[str] = Query(..., description="List of sentences to embed"),
71
- pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max")
 
72
  ):
73
  """
74
  Generate embeddings for a list of sentences.
75
-
76
- Args:
77
- sentences: List of input sentences
78
- pooling: Pooling strategy ('mean', 'cls', or 'max')
79
-
80
- Returns:
81
- EmbeddingResponse with embeddings and metadata
82
  """
83
  # Validate pooling method
84
  if pooling not in ['mean', 'cls', 'max']:
85
- return {
86
- "error": "Invalid pooling method. Choose from: mean, cls, max"
87
- }
 
 
 
88
 
89
  # Generate embeddings
90
- embeddings = clinical_bert.get_embeddings(sentences, pooling=pooling)
91
 
92
  # Convert to list for JSON serialization
93
  embeddings_list = embeddings.tolist()
@@ -95,36 +92,34 @@ async def get_embeddings(
95
  return EmbeddingResponse(
96
  embeddings=embeddings_list,
97
  shape=list(embeddings.shape),
98
- pooling=pooling
 
99
  )
100
 
101
-
102
  @app.get("/health")
103
  async def health_check():
104
  """Health check endpoint"""
105
  return {
106
  "status": "healthy",
107
- "model_loaded": clinical_bert is not None
108
  }
109
 
110
-
111
  @app.post("/embeddings/batch")
112
  async def post_embeddings_batch(request: EmbeddingRequest):
113
  """
114
  POST endpoint for batch embedding generation.
115
-
116
- Args:
117
- request: EmbeddingRequest with sentences and pooling method
118
-
119
- Returns:
120
- EmbeddingResponse with embeddings and metadata
121
  """
122
  # Validate pooling method
123
  if request.pooling not in ['mean', 'cls', 'max']:
124
  raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")
125
 
 
 
 
 
 
126
  # Generate embeddings
127
- embeddings = clinical_bert.get_embeddings(request.sentences, pooling=request.pooling)
128
 
129
  # Convert to list for JSON serialization
130
  embeddings_list = embeddings.tolist()
@@ -132,24 +127,18 @@ async def post_embeddings_batch(request: EmbeddingRequest):
132
  return EmbeddingResponse(
133
  embeddings=embeddings_list,
134
  shape=list(embeddings.shape),
135
- pooling=request.pooling
 
136
  )
137
 
138
-
139
  @app.post("/embeddings/file")
140
  async def upload_file_embeddings(
141
  file: UploadFile = File(...),
142
- pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max")
 
143
  ):
144
  """
145
  Upload a CSV file with terms and get embeddings back as CSV.
146
-
147
- Args:
148
- file: CSV file with one column containing terms
149
- pooling: Pooling strategy ('mean', 'cls', or 'max')
150
-
151
- Returns:
152
- CSV file with embeddings
153
  """
154
  # Validate file type
155
  if not file.filename.endswith('.csv'):
@@ -159,6 +148,11 @@ async def upload_file_embeddings(
159
  if pooling not in ['mean', 'cls', 'max']:
160
  raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")
161
 
 
 
 
 
 
162
  try:
163
  # Read CSV file
164
  contents = await file.read()
@@ -178,7 +172,7 @@ async def upload_file_embeddings(
178
  raise HTTPException(status_code=400, detail="No terms found in CSV")
179
 
180
  # Generate embeddings
181
- embeddings = clinical_bert.get_embeddings(terms, pooling=pooling)
182
 
183
  # Create output CSV
184
  output = io.StringIO()
@@ -206,11 +200,10 @@ async def upload_file_embeddings(
206
  except Exception as e:
207
  raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
208
 
209
-
210
  if __name__ == "__main__":
211
  # Run the server
212
  uvicorn.run(
213
- "main:app",
214
  host="0.0.0.0",
215
  port=8000,
216
  reload=False
 
1
+ from typing import List, Optional
2
  from fastapi import FastAPI, Query, UploadFile, File, HTTPException
3
  from fastapi.responses import RedirectResponse
4
  from fastapi.responses import StreamingResponse
 
11
  import csv
12
  import os
13
 
14
+ from clinical_embedding import ModelManager
15
 
16
  # Pydantic models for request/response
17
  class EmbeddingRequest(BaseModel):
18
  sentences: List[str]
19
  pooling: str = 'cls'
20
+ model: str = 'clinical_bert'
21
 
22
  class EmbeddingResponse(BaseModel):
23
  embeddings: List[List[float]]
24
  shape: List[int]
25
  pooling: str
26
+ model: str
27
 
28
  # Initialize FastAPI app
29
  app = FastAPI(
30
+ title="Clinical Embedding API",
31
+ description="API for generating embeddings using various models (ClinicalBERT, BERT, Word2Vec)",
32
+ version="2.0.0"
33
  )
34
 
35
  # Add CORS middleware to allow web page access
 
44
  # Serve static files
45
  app.mount("/app/static", StaticFiles(directory="static"), name="static")
46
 
47
+ # Initialize model manager (global instance)
48
+ model_manager = ModelManager()
 
49
 
50
  @app.on_event("startup")
51
  async def startup_event():
52
+ """
53
+ Load default model on startup.
54
+ Other models will be loaded on demand (see ModelManager).
55
+ """
56
+ # Pre-load ClinicalBERT as it's the default
57
+ model_manager.get_model('clinical_bert')
58
 
59
  @app.get("/")
60
  async def root():
 
62
 
63
  @app.get("/browser/")
64
  def get_browser():
 
65
  return FileResponse(os.path.join("static", "browser", "index.html"))
66
 
 
67
@app.get("/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(
    sentences: List[str] = Query(..., description="List of sentences to embed"),
    pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max"),
    model: str = Query('clinical_bert', description="Model to use: clinical_bert, bert, word2vec")
):
    """
    Generate embeddings for a list of sentences.
    Supports bracketed text for context-aware specific extraction.
    """
    # Reject unknown pooling strategies up front.
    if pooling not in ('mean', 'cls', 'max'):
        raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")

    # Resolve the requested embedder; unknown names become 400s.
    try:
        embedder = model_manager.get_model(model)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    vectors = embedder.get_embeddings(sentences, pooling=pooling)

    # JSON-serializable payload plus echo of the request parameters.
    return EmbeddingResponse(
        embeddings=vectors.tolist(),
        shape=list(vectors.shape),
        pooling=pooling,
        model=model
    )
 
 
99
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    loaded = list(model_manager.models.keys())
    # Report which embedders have been instantiated so far.
    return {"status": "healthy", "loaded_models": loaded}
106
 
 
107
@app.post("/embeddings/batch")
async def post_embeddings_batch(request: EmbeddingRequest):
    """
    POST endpoint for batch embedding generation.
    """
    # Reject unknown pooling strategies up front.
    if request.pooling not in ('mean', 'cls', 'max'):
        raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")

    # Resolve the requested embedder; unknown names become 400s.
    try:
        embedder = model_manager.get_model(request.model)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    vectors = embedder.get_embeddings(request.sentences, pooling=request.pooling)

    # JSON-serializable payload plus echo of the request parameters.
    return EmbeddingResponse(
        embeddings=vectors.tolist(),
        shape=list(vectors.shape),
        pooling=request.pooling,
        model=request.model
    )
133
 
 
134
  @app.post("/embeddings/file")
135
  async def upload_file_embeddings(
136
  file: UploadFile = File(...),
137
+ pooling: str = Query('cls', description="Pooling strategy: mean, cls, or max"),
138
+ model: str = Query('clinical_bert', description="Model to use: clinical_bert, bert, word2vec")
139
  ):
140
  """
141
  Upload a CSV file with terms and get embeddings back as CSV.
 
 
 
 
 
 
 
142
  """
143
  # Validate file type
144
  if not file.filename.endswith('.csv'):
 
148
  if pooling not in ['mean', 'cls', 'max']:
149
  raise HTTPException(status_code=400, detail="Invalid pooling method. Choose from: mean, cls, max")
150
 
151
+ try:
152
+ embedder = model_manager.get_model(model)
153
+ except ValueError as e:
154
+ raise HTTPException(status_code=400, detail=str(e))
155
+
156
  try:
157
  # Read CSV file
158
  contents = await file.read()
 
172
  raise HTTPException(status_code=400, detail="No terms found in CSV")
173
 
174
  # Generate embeddings
175
+ embeddings = embedder.get_embeddings(terms, pooling=pooling)
176
 
177
  # Create output CSV
178
  output = io.StringIO()
 
200
  except Exception as e:
201
  raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
202
 
 
203
  if __name__ == "__main__":
204
  # Run the server
205
  uvicorn.run(
206
+ "server_clinical_embedding:app",
207
  host="0.0.0.0",
208
  port=8000,
209
  reload=False
app/static/browser/index.html CHANGED
@@ -1,5 +1,6 @@
1
  <!DOCTYPE html>
2
  <html lang="en">
 
3
  <head>
4
  <meta charset="UTF-8">
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
@@ -10,33 +11,33 @@
10
  padding: 0;
11
  box-sizing: border-box;
12
  }
13
-
14
  body {
15
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
16
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
17
  min-height: 100vh;
18
  padding: 20px;
19
  }
20
-
21
  .container {
22
  max-width: 1200px;
23
  margin: 0 auto;
24
  }
25
-
26
  h1 {
27
  color: white;
28
  text-align: center;
29
  margin-bottom: 30px;
30
  font-size: 2.5em;
31
- text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
32
  }
33
-
34
  .tabs {
35
  display: flex;
36
  gap: 10px;
37
  margin-bottom: 20px;
38
  }
39
-
40
  .tab-button {
41
  flex: 1;
42
  padding: 15px;
@@ -49,41 +50,43 @@
49
  transition: all 0.3s;
50
  color: #667eea;
51
  }
52
-
53
  .tab-button:hover {
54
  background: #f0f0f0;
55
  }
56
-
57
  .tab-button.active {
58
  background: white;
59
  color: #764ba2;
60
- box-shadow: 0 -2px 10px rgba(0,0,0,0.1);
61
  }
62
-
63
  .tab-content {
64
  display: none;
65
  background: white;
66
  padding: 30px;
67
  border-radius: 0 0 12px 12px;
68
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
69
  }
70
-
71
  .tab-content.active {
72
  display: block;
73
  }
74
-
75
  .form-group {
76
  margin-bottom: 20px;
77
  }
78
-
79
  label {
80
  display: block;
81
  margin-bottom: 8px;
82
  font-weight: bold;
83
  color: #333;
84
  }
85
-
86
- textarea, input[type="file"], select {
 
 
87
  width: 100%;
88
  padding: 12px;
89
  border: 2px solid #e0e0e0;
@@ -92,17 +95,18 @@
92
  font-family: 'Courier New', monospace;
93
  transition: border-color 0.3s;
94
  }
95
-
96
- textarea:focus, select:focus {
 
97
  outline: none;
98
  border-color: #667eea;
99
  }
100
-
101
  textarea {
102
  min-height: 150px;
103
  resize: vertical;
104
  }
105
-
106
  button {
107
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
108
  color: white;
@@ -114,22 +118,22 @@
114
  cursor: pointer;
115
  transition: transform 0.2s, box-shadow 0.2s;
116
  }
117
-
118
  button:hover {
119
  transform: translateY(-2px);
120
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
121
  }
122
-
123
  button:active {
124
  transform: translateY(0);
125
  }
126
-
127
  button:disabled {
128
  background: #ccc;
129
  cursor: not-allowed;
130
  transform: none;
131
  }
132
-
133
  .loading {
134
  display: none;
135
  text-align: center;
@@ -137,11 +141,11 @@
137
  color: #667eea;
138
  font-weight: bold;
139
  }
140
-
141
  .loading.show {
142
  display: block;
143
  }
144
-
145
  .spinner {
146
  border: 4px solid #f3f3f3;
147
  border-top: 4px solid #667eea;
@@ -151,12 +155,17 @@
151
  animation: spin 1s linear infinite;
152
  margin: 0 auto 10px;
153
  }
154
-
155
  @keyframes spin {
156
- 0% { transform: rotate(0deg); }
157
- 100% { transform: rotate(360deg); }
 
 
 
 
 
158
  }
159
-
160
  .error {
161
  background: #fee;
162
  color: #c33;
@@ -166,11 +175,11 @@
166
  border-left: 4px solid #c33;
167
  display: none;
168
  }
169
-
170
  .error.show {
171
  display: block;
172
  }
173
-
174
  .success {
175
  background: #efe;
176
  color: #3c3;
@@ -180,11 +189,11 @@
180
  border-left: 4px solid #3c3;
181
  display: none;
182
  }
183
-
184
  .success.show {
185
  display: block;
186
  }
187
-
188
  .info {
189
  background: #e3f2fd;
190
  padding: 15px;
@@ -193,7 +202,7 @@
193
  color: #1976d2;
194
  border-left: 4px solid #1976d2;
195
  }
196
-
197
  .download-section {
198
  display: none;
199
  margin-top: 20px;
@@ -202,50 +211,51 @@
202
  border-radius: 6px;
203
  text-align: center;
204
  }
205
-
206
  .download-section.show {
207
  display: block;
208
  }
209
-
210
  .settings {
211
  display: flex;
212
  gap: 20px;
213
  align-items: end;
214
  }
215
-
216
  .settings .form-group {
217
  flex: 1;
218
  }
219
  </style>
220
  </head>
 
221
  <body>
222
  <div class="container">
223
  <h1>🧬 Clinical BERT Embeddings</h1>
224
-
225
  <div class="tabs">
226
  <button class="tab-button active" onclick="switchTab('inline')">📝 Inline Embeddings</button>
227
  <button class="tab-button" onclick="switchTab('file')">📁 File Embeddings</button>
228
  </div>
229
-
230
  <!-- Inline Embeddings Tab -->
231
  <div id="inline-tab" class="tab-content active">
232
  <div class="info">
233
  💡 Enter medical terms separated by commas or new lines. Example: Heart Attack, Myocardial Infarction
234
  </div>
235
-
236
  <div class="error" id="inline-error"></div>
237
  <div class="success" id="inline-success"></div>
238
-
239
  <div class="settings">
240
  <div class="form-group" style="flex: 3;">
241
  <label for="inline-terms">Medical Terms:</label>
242
  <textarea id="inline-terms" placeholder="Enter terms here (comma or newline separated)...
243
  Example:
244
- Heart Attack
245
- Myocardial Infarction
246
  Diabetes"></textarea>
247
  </div>
248
-
249
  <div class="form-group">
250
  <label for="inline-pooling">Pooling:</label>
251
  <select id="inline-pooling">
@@ -254,36 +264,68 @@ Diabetes"></textarea>
254
  <option value="max">Max</option>
255
  </select>
256
  </div>
 
 
257
  </div>
258
-
259
- <button onclick="getInlineEmbeddings()" id="inline-btn">Generate Embeddings</button>
260
-
261
  <div class="loading" id="inline-loading">
262
  <div class="spinner"></div>
263
- Processing...
264
  </div>
265
-
266
- <div class="form-group" style="margin-top: 20px;">
267
- <label for="inline-results">Embeddings (JSON):</label>
268
- <textarea id="inline-results" readonly placeholder="Results will appear here..."></textarea>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  </div>
270
  </div>
271
-
272
  <!-- File Embeddings Tab -->
273
  <div id="file-tab" class="tab-content">
274
  <div class="info">
275
  💡 Upload a CSV file with one column containing medical terms. The first row should be the column name.
276
  </div>
277
-
278
  <div class="error" id="file-error"></div>
279
  <div class="success" id="file-success"></div>
280
-
281
  <div class="settings">
282
- <div class="form-group" style="flex: 3;">
283
  <label for="file-input">Select CSV File:</label>
284
  <input type="file" id="file-input" accept=".csv">
285
  </div>
286
-
287
  <div class="form-group">
288
  <label for="file-pooling">Pooling:</label>
289
  <select id="file-pooling">
@@ -292,15 +334,24 @@ Diabetes"></textarea>
292
  <option value="max">Max</option>
293
  </select>
294
  </div>
 
 
 
 
 
 
 
 
 
295
  </div>
296
-
297
  <button onclick="uploadFileEmbeddings()" id="file-btn">Process File</button>
298
-
299
  <div class="loading" id="file-loading">
300
  <div class="spinner"></div>
301
  Processing file...
302
  </div>
303
-
304
  <div class="download-section" id="download-section">
305
  <h3>✅ Embeddings Ready!</h3>
306
  <p style="margin: 10px 0;">Your embeddings have been generated successfully.</p>
@@ -308,133 +359,201 @@ Diabetes"></textarea>
308
  </div>
309
  </div>
310
  </div>
311
-
312
  <script>
313
  const API_URL = 'https://santanche-clinical-embedding.hf.space';
314
  let downloadBlob = null;
315
  let downloadFilename = null;
316
-
317
  function switchTab(tab) {
318
  // Update tab buttons
319
  document.querySelectorAll('.tab-button').forEach(btn => {
320
  btn.classList.remove('active');
321
  });
322
  event.target.classList.add('active');
323
-
324
  // Update tab content
325
  document.querySelectorAll('.tab-content').forEach(content => {
326
  content.classList.remove('active');
327
  });
328
  document.getElementById(`${tab}-tab`).classList.add('active');
329
  }
330
-
331
  function showError(tabId, message) {
332
  const errorDiv = document.getElementById(`${tabId}-error`);
333
  errorDiv.textContent = message;
334
  errorDiv.classList.add('show');
335
  setTimeout(() => errorDiv.classList.remove('show'), 5000);
336
  }
337
-
338
  function showSuccess(tabId, message) {
339
  const successDiv = document.getElementById(`${tabId}-success`);
340
  successDiv.textContent = message;
341
  successDiv.classList.add('show');
342
  setTimeout(() => successDiv.classList.remove('show'), 5000);
343
  }
344
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  async function getInlineEmbeddings() {
346
  const termsText = document.getElementById('inline-terms').value.trim();
347
  const pooling = document.getElementById('inline-pooling').value;
348
- const resultsArea = document.getElementById('inline-results');
349
  const loadingDiv = document.getElementById('inline-loading');
350
  const btn = document.getElementById('inline-btn');
351
-
 
352
  if (!termsText) {
353
  showError('inline', 'Please enter some terms');
354
  return;
355
  }
356
-
357
- // Parse terms (split by comma or newline)
358
  const terms = termsText
359
- .split(/[,\n]+/)
360
  .map(t => t.trim())
361
  .filter(t => t.length > 0);
362
-
363
  if (terms.length === 0) {
364
  showError('inline', 'No valid terms found');
365
  return;
366
  }
367
-
368
  // Show loading
369
  loadingDiv.classList.add('show');
370
  btn.disabled = true;
371
- resultsArea.value = '';
372
-
 
 
 
 
 
 
373
  try {
374
- const response = await fetch(`${API_URL}/embeddings/batch`, {
375
- method: 'POST',
376
- headers: {
377
- 'Content-Type': 'application/json',
378
- },
379
- body: JSON.stringify({
380
- sentences: terms,
381
- pooling: pooling
382
- })
 
 
 
 
 
 
 
 
 
 
383
  });
384
-
385
- if (!response.ok) {
386
- throw new Error(`HTTP error! status: ${response.status}`);
387
- }
388
-
389
- const data = await response.json();
390
- resultsArea.value = JSON.stringify(data, null, 2);
391
- showSuccess('inline', `Generated embeddings for ${terms.length} terms (shape: ${data.shape})`);
392
  } catch (error) {
393
- showError('inline', `Error: ${error.message}`);
394
- resultsArea.value = '';
395
  } finally {
396
  loadingDiv.classList.remove('show');
397
  btn.disabled = false;
398
  }
399
  }
400
-
401
  async function uploadFileEmbeddings() {
402
  const fileInput = document.getElementById('file-input');
403
  const pooling = document.getElementById('file-pooling').value;
 
404
  const loadingDiv = document.getElementById('file-loading');
405
  const btn = document.getElementById('file-btn');
406
  const downloadSection = document.getElementById('download-section');
407
-
408
  if (!fileInput.files || fileInput.files.length === 0) {
409
  showError('file', 'Please select a CSV file');
410
  return;
411
  }
412
-
413
  const file = fileInput.files[0];
414
-
415
  // Show loading
416
  loadingDiv.classList.add('show');
417
  btn.disabled = true;
418
  downloadSection.classList.remove('show');
419
-
420
  try {
421
  const formData = new FormData();
422
  formData.append('file', file);
423
-
424
- const response = await fetch(`${API_URL}/embeddings/file?pooling=${pooling}`, {
425
  method: 'POST',
426
  body: formData
427
  });
428
-
429
  if (!response.ok) {
430
  const errorData = await response.json();
431
  throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
432
  }
433
-
434
  // Get the blob for download
435
  downloadBlob = await response.blob();
436
  downloadFilename = `embeddings_${file.name}`;
437
-
438
  // Show download section
439
  downloadSection.classList.add('show');
440
  showSuccess('file', 'File processed successfully!');
@@ -446,13 +565,13 @@ Diabetes"></textarea>
446
  btn.disabled = false;
447
  }
448
  }
449
-
450
  function downloadResults() {
451
  if (!downloadBlob) {
452
  showError('file', 'No data to download');
453
  return;
454
  }
455
-
456
  const url = window.URL.createObjectURL(downloadBlob);
457
  const a = document.createElement('a');
458
  a.href = url;
@@ -464,4 +583,5 @@ Diabetes"></textarea>
464
  }
465
  </script>
466
  </body>
467
- </html>
 
 
1
  <!DOCTYPE html>
2
  <html lang="en">
3
+
4
  <head>
5
  <meta charset="UTF-8">
6
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
 
11
  padding: 0;
12
  box-sizing: border-box;
13
  }
14
+
15
  body {
16
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
17
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
18
  min-height: 100vh;
19
  padding: 20px;
20
  }
21
+
22
  .container {
23
  max-width: 1200px;
24
  margin: 0 auto;
25
  }
26
+
27
  h1 {
28
  color: white;
29
  text-align: center;
30
  margin-bottom: 30px;
31
  font-size: 2.5em;
32
+ text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
33
  }
34
+
35
  .tabs {
36
  display: flex;
37
  gap: 10px;
38
  margin-bottom: 20px;
39
  }
40
+
41
  .tab-button {
42
  flex: 1;
43
  padding: 15px;
 
50
  transition: all 0.3s;
51
  color: #667eea;
52
  }
53
+
54
  .tab-button:hover {
55
  background: #f0f0f0;
56
  }
57
+
58
  .tab-button.active {
59
  background: white;
60
  color: #764ba2;
61
+ box-shadow: 0 -2px 10px rgba(0, 0, 0, 0.1);
62
  }
63
+
64
  .tab-content {
65
  display: none;
66
  background: white;
67
  padding: 30px;
68
  border-radius: 0 0 12px 12px;
69
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
70
  }
71
+
72
  .tab-content.active {
73
  display: block;
74
  }
75
+
76
  .form-group {
77
  margin-bottom: 20px;
78
  }
79
+
80
  label {
81
  display: block;
82
  margin-bottom: 8px;
83
  font-weight: bold;
84
  color: #333;
85
  }
86
+
87
+ textarea,
88
+ input[type="file"],
89
+ select {
90
  width: 100%;
91
  padding: 12px;
92
  border: 2px solid #e0e0e0;
 
95
  font-family: 'Courier New', monospace;
96
  transition: border-color 0.3s;
97
  }
98
+
99
+ textarea:focus,
100
+ select:focus {
101
  outline: none;
102
  border-color: #667eea;
103
  }
104
+
105
  textarea {
106
  min-height: 150px;
107
  resize: vertical;
108
  }
109
+
110
  button {
111
  background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
112
  color: white;
 
118
  cursor: pointer;
119
  transition: transform 0.2s, box-shadow 0.2s;
120
  }
121
+
122
  button:hover {
123
  transform: translateY(-2px);
124
  box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
125
  }
126
+
127
  button:active {
128
  transform: translateY(0);
129
  }
130
+
131
  button:disabled {
132
  background: #ccc;
133
  cursor: not-allowed;
134
  transform: none;
135
  }
136
+
137
  .loading {
138
  display: none;
139
  text-align: center;
 
141
  color: #667eea;
142
  font-weight: bold;
143
  }
144
+
145
  .loading.show {
146
  display: block;
147
  }
148
+
149
  .spinner {
150
  border: 4px solid #f3f3f3;
151
  border-top: 4px solid #667eea;
 
155
  animation: spin 1s linear infinite;
156
  margin: 0 auto 10px;
157
  }
158
+
159
  @keyframes spin {
160
+ 0% {
161
+ transform: rotate(0deg);
162
+ }
163
+
164
+ 100% {
165
+ transform: rotate(360deg);
166
+ }
167
  }
168
+
169
  .error {
170
  background: #fee;
171
  color: #c33;
 
175
  border-left: 4px solid #c33;
176
  display: none;
177
  }
178
+
179
  .error.show {
180
  display: block;
181
  }
182
+
183
  .success {
184
  background: #efe;
185
  color: #3c3;
 
189
  border-left: 4px solid #3c3;
190
  display: none;
191
  }
192
+
193
  .success.show {
194
  display: block;
195
  }
196
+
197
  .info {
198
  background: #e3f2fd;
199
  padding: 15px;
 
202
  color: #1976d2;
203
  border-left: 4px solid #1976d2;
204
  }
205
+
206
  .download-section {
207
  display: none;
208
  margin-top: 20px;
 
211
  border-radius: 6px;
212
  text-align: center;
213
  }
214
+
215
  .download-section.show {
216
  display: block;
217
  }
218
+
219
  .settings {
220
  display: flex;
221
  gap: 20px;
222
  align-items: end;
223
  }
224
+
225
  .settings .form-group {
226
  flex: 1;
227
  }
228
  </style>
229
  </head>
230
+
231
  <body>
232
  <div class="container">
233
  <h1>🧬 Clinical BERT Embeddings</h1>
234
+
235
  <div class="tabs">
236
  <button class="tab-button active" onclick="switchTab('inline')">📝 Inline Embeddings</button>
237
  <button class="tab-button" onclick="switchTab('file')">📁 File Embeddings</button>
238
  </div>
239
+
240
  <!-- Inline Embeddings Tab -->
241
  <div id="inline-tab" class="tab-content active">
242
  <div class="info">
243
  💡 Enter medical terms separated by commas or new lines. Example: Heart Attack, Myocardial Infarction
244
  </div>
245
+
246
  <div class="error" id="inline-error"></div>
247
  <div class="success" id="inline-success"></div>
248
+
249
  <div class="settings">
250
  <div class="form-group" style="flex: 3;">
251
  <label for="inline-terms">Medical Terms:</label>
252
  <textarea id="inline-terms" placeholder="Enter terms here (comma or newline separated)...
253
  Example:
254
+ The patient had a [heart attack] yesterday.
255
+ [Myocardial Infarction] is serious.
256
  Diabetes"></textarea>
257
  </div>
258
+
259
  <div class="form-group">
260
  <label for="inline-pooling">Pooling:</label>
261
  <select id="inline-pooling">
 
264
  <option value="max">Max</option>
265
  </select>
266
  </div>
267
+
268
+ <!-- Model selector removed for Inline tab -->
269
  </div>
270
+
271
+ <button onclick="getInlineEmbeddings()" id="inline-btn">Generate Embeddings (All Models)</button>
272
+
273
  <div class="loading" id="inline-loading">
274
  <div class="spinner"></div>
275
+ Processing 3 models...
276
  </div>
277
+
278
+ <div id="results-container" style="display: none; margin-top: 20px;">
279
+ <!-- Clinical BERT -->
280
+ <div class="result-block">
281
+ <h3
282
+ style="color: #667eea; border-bottom: 2px solid #667eea; padding-bottom: 5px; margin-bottom: 10px;">
283
+ 🧬 Clinical BERT</h3>
284
+ <label>Visualization:</label>
285
+ <div id="viz-clinical_bert" style="margin-bottom: 15px;"></div>
286
+ <label>JSON:</label>
287
+ <textarea id="json-clinical_bert" readonly style="height: 100px;"></textarea>
288
+ </div>
289
+
290
+ <!-- Standard BERT -->
291
+ <div class="result-block" style="margin-top: 30px;">
292
+ <h3
293
+ style="color: #764ba2; border-bottom: 2px solid #764ba2; padding-bottom: 5px; margin-bottom: 10px;">
294
+ 🤖 Standard BERT</h3>
295
+ <label>Visualization:</label>
296
+ <div id="viz-bert" style="margin-bottom: 15px;"></div>
297
+ <label>JSON:</label>
298
+ <textarea id="json-bert" readonly style="height: 100px;"></textarea>
299
+ </div>
300
+
301
+ <!-- Word2Vec -->
302
+ <div class="result-block" style="margin-top: 30px;">
303
+ <h3
304
+ style="color: #2c3e50; border-bottom: 2px solid #2c3e50; padding-bottom: 5px; margin-bottom: 10px;">
305
+ 📚 Word2Vec</h3>
306
+ <label>Visualization:</label>
307
+ <div id="viz-word2vec" style="margin-bottom: 15px;"></div>
308
+ <label>JSON:</label>
309
+ <textarea id="json-word2vec" readonly style="height: 100px;"></textarea>
310
+ </div>
311
  </div>
312
  </div>
313
+
314
  <!-- File Embeddings Tab -->
315
  <div id="file-tab" class="tab-content">
316
  <div class="info">
317
  💡 Upload a CSV file with one column containing medical terms. The first row should be the column name.
318
  </div>
319
+
320
  <div class="error" id="file-error"></div>
321
  <div class="success" id="file-success"></div>
322
+
323
  <div class="settings">
324
+ <div class="form-group" style="flex: 2;">
325
  <label for="file-input">Select CSV File:</label>
326
  <input type="file" id="file-input" accept=".csv">
327
  </div>
328
+
329
  <div class="form-group">
330
  <label for="file-pooling">Pooling:</label>
331
  <select id="file-pooling">
 
334
  <option value="max">Max</option>
335
  </select>
336
  </div>
337
+
338
+ <div class="form-group">
339
+ <label for="file-model">Model:</label>
340
+ <select id="file-model">
341
+ <option value="clinical_bert" selected>Clinical BERT</option>
342
+ <option value="bert">Standard BERT</option>
343
+ <option value="word2vec">Word2Vec</option>
344
+ </select>
345
+ </div>
346
  </div>
347
+
348
  <button onclick="uploadFileEmbeddings()" id="file-btn">Process File</button>
349
+
350
  <div class="loading" id="file-loading">
351
  <div class="spinner"></div>
352
  Processing file...
353
  </div>
354
+
355
  <div class="download-section" id="download-section">
356
  <h3>✅ Embeddings Ready!</h3>
357
  <p style="margin: 10px 0;">Your embeddings have been generated successfully.</p>
 
359
  </div>
360
  </div>
361
  </div>
362
+
363
  <script>
364
  const API_URL = 'https://santanche-clinical-embedding.hf.space';
365
  let downloadBlob = null;
366
  let downloadFilename = null;
367
+
368
  function switchTab(tab) {
369
  // Update tab buttons
370
  document.querySelectorAll('.tab-button').forEach(btn => {
371
  btn.classList.remove('active');
372
  });
373
  event.target.classList.add('active');
374
+
375
  // Update tab content
376
  document.querySelectorAll('.tab-content').forEach(content => {
377
  content.classList.remove('active');
378
  });
379
  document.getElementById(`${tab}-tab`).classList.add('active');
380
  }
381
+
382
  function showError(tabId, message) {
383
  const errorDiv = document.getElementById(`${tabId}-error`);
384
  errorDiv.textContent = message;
385
  errorDiv.classList.add('show');
386
  setTimeout(() => errorDiv.classList.remove('show'), 5000);
387
  }
388
+
389
  function showSuccess(tabId, message) {
390
  const successDiv = document.getElementById(`${tabId}-success`);
391
  successDiv.textContent = message;
392
  successDiv.classList.add('show');
393
  setTimeout(() => successDiv.classList.remove('show'), 5000);
394
  }
395
+
396
+ function createHeatmap(data, sentences, containerId) {
397
+ const container = document.getElementById(containerId);
398
+ container.innerHTML = '';
399
+
400
+ data.forEach((embedding, index) => {
401
+ const sentence = sentences[index];
402
+ const rowTitle = document.createElement('div');
403
+ rowTitle.style.fontWeight = 'bold';
404
+ rowTitle.style.fontSize = '0.9em';
405
+ rowTitle.style.marginBottom = '2px';
406
+ rowTitle.style.marginTop = '10px';
407
+ rowTitle.textContent = `${index + 1}. ${sentence}`;
408
+ container.appendChild(rowTitle);
409
+
410
+ const row = document.createElement('div');
411
+ row.style.display = 'flex';
412
+ row.style.flexWrap = 'wrap';
413
+ row.style.gap = '1px';
414
+ row.style.maxWidth = '100%';
415
+
416
+ embedding.forEach(val => {
417
+ const block = document.createElement('div');
418
+ block.style.width = '8px';
419
+ block.style.height = '12px';
420
+ block.title = val.toFixed(4);
421
+
422
+ const intensity = Math.min(Math.abs(val) * 2, 1);
423
+
424
+ if (val > 0) {
425
+ block.style.backgroundColor = `rgba(0, 0, 255, ${intensity})`;
426
+ } else {
427
+ block.style.backgroundColor = `rgba(255, 0, 0, ${intensity})`;
428
+ }
429
+
430
+ row.appendChild(block);
431
+ });
432
+ container.appendChild(row);
433
+ });
434
+ }
435
+
436
+ async function fetchModelEmbeddings(modelName, terms, pooling) {
437
+ const response = await fetch(`${API_URL}/embeddings/batch`, {
438
+ method: 'POST',
439
+ headers: { 'Content-Type': 'application/json' },
440
+ body: JSON.stringify({
441
+ sentences: terms,
442
+ pooling: pooling,
443
+ model: modelName
444
+ })
445
+ });
446
+
447
+ if (!response.ok) {
448
+ const err = await response.json();
449
+ throw new Error(err.detail || `HTTP error ${response.status}`);
450
+ }
451
+ return await response.json();
452
+ }
453
+
454
  async function getInlineEmbeddings() {
455
  const termsText = document.getElementById('inline-terms').value.trim();
456
  const pooling = document.getElementById('inline-pooling').value;
 
457
  const loadingDiv = document.getElementById('inline-loading');
458
  const btn = document.getElementById('inline-btn');
459
+ const resultsContainer = document.getElementById('results-container');
460
+
461
  if (!termsText) {
462
  showError('inline', 'Please enter some terms');
463
  return;
464
  }
465
+
 
466
  const terms = termsText
467
+ .split(/\n+/)
468
  .map(t => t.trim())
469
  .filter(t => t.length > 0);
470
+
471
  if (terms.length === 0) {
472
  showError('inline', 'No valid terms found');
473
  return;
474
  }
475
+
476
  // Show loading
477
  loadingDiv.classList.add('show');
478
  btn.disabled = true;
479
+ resultsContainer.style.display = 'none';
480
+
481
+ // Clear previous results
482
+ ['clinical_bert', 'bert', 'word2vec'].forEach(m => {
483
+ document.getElementById(`viz-${m}`).innerHTML = '';
484
+ document.getElementById(`json-${m}`).value = '';
485
+ });
486
+
487
  try {
488
+ // Fetch all 3 models in parallel
489
+ const models = ['clinical_bert', 'bert', 'word2vec'];
490
+ const promises = models.map(m => fetchModelEmbeddings(m, terms, pooling)
491
+ .then(data => ({ status: 'fulfilled', model: m, data: data }))
492
+ .catch(err => ({ status: 'rejected', model: m, error: err }))
493
+ );
494
+
495
+ const results = await Promise.all(promises);
496
+
497
+ resultsContainer.style.display = 'block';
498
+
499
+ results.forEach(res => {
500
+ const jsonArea = document.getElementById(`json-${res.model}`);
501
+ if (res.status === 'fulfilled') {
502
+ jsonArea.value = JSON.stringify(res.data, null, 2);
503
+ createHeatmap(res.data.embeddings, terms, `viz-${res.model}`);
504
+ } else {
505
+ jsonArea.value = `Error: ${res.error.message}`;
506
+ }
507
  });
508
+
509
+ showSuccess('inline', `Generated embeddings for ${terms.length} terms across 3 models.`);
510
+
 
 
 
 
 
511
  } catch (error) {
512
+ showError('inline', `Critical Error: ${error.message}`);
 
513
  } finally {
514
  loadingDiv.classList.remove('show');
515
  btn.disabled = false;
516
  }
517
  }
518
+
519
  async function uploadFileEmbeddings() {
520
  const fileInput = document.getElementById('file-input');
521
  const pooling = document.getElementById('file-pooling').value;
522
+ const model = document.getElementById('file-model').value;
523
  const loadingDiv = document.getElementById('file-loading');
524
  const btn = document.getElementById('file-btn');
525
  const downloadSection = document.getElementById('download-section');
526
+
527
  if (!fileInput.files || fileInput.files.length === 0) {
528
  showError('file', 'Please select a CSV file');
529
  return;
530
  }
531
+
532
  const file = fileInput.files[0];
533
+
534
  // Show loading
535
  loadingDiv.classList.add('show');
536
  btn.disabled = true;
537
  downloadSection.classList.remove('show');
538
+
539
  try {
540
  const formData = new FormData();
541
  formData.append('file', file);
542
+
543
+ const response = await fetch(`${API_URL}/embeddings/file?pooling=${pooling}&model=${model}`, {
544
  method: 'POST',
545
  body: formData
546
  });
547
+
548
  if (!response.ok) {
549
  const errorData = await response.json();
550
  throw new Error(errorData.detail || `HTTP error! status: ${response.status}`);
551
  }
552
+
553
  // Get the blob for download
554
  downloadBlob = await response.blob();
555
  downloadFilename = `embeddings_${file.name}`;
556
+
557
  // Show download section
558
  downloadSection.classList.add('show');
559
  showSuccess('file', 'File processed successfully!');
 
565
  btn.disabled = false;
566
  }
567
  }
568
+
569
  function downloadResults() {
570
  if (!downloadBlob) {
571
  showError('file', 'No data to download');
572
  return;
573
  }
574
+
575
  const url = window.URL.createObjectURL(downloadBlob);
576
  const a = document.createElement('a');
577
  a.href = url;
 
583
  }
584
  </script>
585
  </body>
586
+
587
+ </html>
app/verify_backend.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from clinical_embedding import ModelManager
2
+ import numpy as np
3
+
4
+ def test_models():
5
+ mm = ModelManager()
6
+
7
+ print("Testing ClinicalBERT...")
8
+ cbert = mm.get_model('clinical_bert')
9
+ emb_cbert = cbert.get_embeddings(["Patient has [heart attack]."])
10
+ print(f"ClinicalBERT embedding shape: {emb_cbert.shape}")
11
+ assert emb_cbert.shape[1] == 768
12
+
13
+ print("\nTesting Standard BERT...")
14
+ bert = mm.get_model('bert')
15
+ emb_bert = bert.get_embeddings(["Patient has [heart attack]."])
16
+ print(f"Standard BERT embedding shape: {emb_bert.shape}")
17
+ assert emb_bert.shape[1] == 768
18
+
19
+ print("\nTesting Word2Vec (loading might fail if model not found, checking fail-safe)...")
20
+ try:
21
+ w2v = mm.get_model('word2vec')
22
+ if w2v.model:
23
+ emb_w2v = w2v.get_embeddings(["Patient has [heart attack]."])
24
+ print(f"Word2Vec embedding shape: {emb_w2v.shape}")
25
+ # Glove 50
26
+ if emb_w2v.size > 0:
27
+ assert emb_w2v.shape[1] == 50
28
+ else:
29
+ print("Word2Vec model could not be loaded (expected if no internet/file), skipping assertion.")
30
+ except Exception as e:
31
+ print(f"Word2Vec test error: {e}")
32
+
33
+ def test_brackets():
34
+ mm = ModelManager()
35
+ model = mm.get_model('clinical_bert')
36
+
37
+ # Test case 1: Context matters?
38
+ # Ideally "apple" in "eat [apple]" vs "company [apple]" might differ slightly in BERT even if focused?
39
+ # Actually if we extract only [apple], the context IS used in the forward pass, then we select tokens.
40
+
41
+ s1 = "I like to eat [apple] pie."
42
+ s2 = "I bought stock in [apple] computer."
43
+
44
+ emb1 = model.get_embeddings([s1])
45
+ emb2 = model.get_embeddings([s2])
46
+
47
+ # Compute cosine similarity
48
+ sim = np.dot(emb1[0], emb2[0]) / (np.linalg.norm(emb1[0]) * np.linalg.norm(emb2[0]))
49
+ print(f"\nSimilarity between '[apple]' in food context vs tech context: {sim:.4f}")
50
+
51
+ # If they are exactly 1.0, then context wasn't used effectively or they are just identical tokens.
52
+ # BERT contextual embeddings should differ.
53
+ if sim < 0.99:
54
+ print("SUCCESS: Embeddings are different (context aware).")
55
+ else:
56
+ print("WARNING: Embeddings are very similar. Might be expected if tokenization is identical and context weak, or logic flaw.")
57
+
58
+ # Test case 2: Brackets vs No Brackets
59
+ s3 = "heart attack"
60
+ s4 = "[heart attack]"
61
+ # Should be identical if s3 is sent as is?
62
+ # Wait, s3 "heart attack" -> full sentence embedding.
63
+ # s4 "[heart attack]" -> extract "heart attack", full sentence is "heart attack".
64
+ # They should be arguably the same.
65
+
66
+ emb3 = model.get_embeddings([s3])
67
+ emb4 = model.get_embeddings([s4])
68
+
69
+ sim_ident = np.dot(emb3[0], emb4[0]) / (np.linalg.norm(emb3[0]) * np.linalg.norm(emb4[0]))
70
+ print(f"Similarity between 'heart attack' and '[heart attack]': {sim_ident:.4f}")
71
+
72
+ # Test case 3: Bracket subset
73
+ s5 = "The patient had a [heart attack] yesterday."
74
+ emb5 = model.get_embeddings([s5])
75
+
76
+ # Compare emb5 (just heart attack) with emb3 (heart attack in isolation)
77
+ # They should be different because emb5 has context "The patient had a... yesterday"
78
+ sim_context = np.dot(emb5[0], emb3[0]) / (np.linalg.norm(emb5[0]) * np.linalg.norm(emb3[0]))
79
+ print(f"Similarity between 'heart attack' (isolated) and '...[heart attack]...' (context): {sim_context:.4f}")
80
+ assert sim_context < 0.99, "Context should affect embedding"
81
+
82
+ if __name__ == "__main__":
83
+ print("=== Running Backend Verification ===")
84
+ test_models()
85
+ test_brackets()
86
+ print("\n=== Verification Complete ===")
requirements.txt CHANGED
@@ -8,6 +8,8 @@ python-multipart==0.0.6
8
  transformers==4.35.2
9
  torch==2.1.1
10
  numpy==1.24.3
 
 
11
 
12
  # Optional: for GPU support, also install:
13
  # torch==2.1.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html
 
8
  transformers==4.35.2
9
  torch==2.1.1
10
  numpy==1.24.3
11
+ gensim==4.3.2
12
+ scikit-learn==1.3.2
13
 
14
  # Optional: for GPU support, also install:
15
  # torch==2.1.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html