Ojochegbeng committed on
Commit
9e9c055
·
verified ·
1 Parent(s): fc9c99f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -11
app.py CHANGED
@@ -28,7 +28,7 @@ def load_model():
28
  global model, tokenizer
29
 
30
  try:
31
- logger.info(f"Loading Qwen3 embedding model on device: {DEVICE}")
32
 
33
  # Load tokenizer and model for Qwen3 embedding
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
@@ -44,7 +44,13 @@ def load_model():
44
 
45
  model.eval()
46
 
47
- logger.info("Qwen3 embedding model loaded successfully")
 
 
 
 
 
 
48
  return True
49
 
50
  except Exception as e:
@@ -62,7 +68,7 @@ def load_model():
62
  return False
63
 
64
  def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
65
- """Generate embeddings for input text(s) using Qwen3 or fallback model"""
66
  global model, tokenizer
67
 
68
  try:
@@ -80,8 +86,9 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
80
 
81
  for text in texts:
82
  try:
83
- # Method 1: Try using the Qwen model directly
84
- if model and tokenizer:
 
85
  inputs = tokenizer(
86
  text,
87
  return_tensors="pt",
@@ -92,8 +99,25 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
92
 
93
  with torch.no_grad():
94
  outputs = model(**inputs)
95
- # Use mean pooling of last hidden state
96
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  embeddings.append(embedding.tolist())
98
 
99
  elif model and hasattr(model, 'encode'):
@@ -106,7 +130,7 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
106
  except Exception as e:
107
  logger.warning(f"Error generating embedding for text: {str(e)}")
108
  # Return zero vector as last resort
109
- embeddings.append([0.0] * 384) # Standard dimension for fallback
110
 
111
  return embeddings[0] if single_text else embeddings
112
 
@@ -114,9 +138,9 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
114
  logger.error(f"Error in generate_embeddings: {str(e)}")
115
  # Return zero vectors as fallback
116
  if single_text:
117
- return [0.0] * 384
118
  else:
119
- return [[0.0] * 384] * len(texts)
120
 
121
  def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
122
  """Compute cosine similarity between two embeddings"""
@@ -193,7 +217,26 @@ def similarity_interface(embedding1: str, embedding2: str) -> float:
193
 
194
  def health_check():
195
  """Health check endpoint"""
196
- return {"status": "healthy", "model_loaded": model is not None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  # Create FastAPI application
199
  app = FastAPI(
 
28
  global model, tokenizer
29
 
30
  try:
31
+ logger.info(f"Loading Qwen3-Embedding-0.6B model on device: {DEVICE}")
32
 
33
  # Load tokenizer and model for Qwen3 embedding
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
44
 
45
  model.eval()
46
 
47
+ # Test the model with a simple input
48
+ test_input = tokenizer("test", return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
49
+ with torch.no_grad():
50
+ test_output = model(**test_input)
51
+ logger.info(f"Model test successful. Output shape: {test_output.last_hidden_state.shape}")
52
+
53
+ logger.info("Qwen3-Embedding-0.6B model loaded successfully")
54
  return True
55
 
56
  except Exception as e:
 
68
  return False
69
 
70
  def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
71
+ """Generate embeddings for input text(s) using Qwen3-Embedding-0.6B model"""
72
  global model, tokenizer
73
 
74
  try:
 
86
 
87
  for text in texts:
88
  try:
89
+ # Method 1: Try using the Qwen3 embedding model directly
90
+ if model and tokenizer and hasattr(model, 'forward'):
91
+ # This is the Qwen3 embedding model
92
  inputs = tokenizer(
93
  text,
94
  return_tensors="pt",
 
99
 
100
  with torch.no_grad():
101
  outputs = model(**inputs)
102
+
103
+ # For Qwen3 embedding models, use the last_hidden_state with mean pooling
104
+ if hasattr(outputs, 'last_hidden_state'):
105
+ # Mean pooling over the sequence length dimension
106
+ attention_mask = inputs.get('attention_mask', None)
107
+ if attention_mask is not None:
108
+ # Apply attention mask for proper mean pooling
109
+ token_embeddings = outputs.last_hidden_state
110
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
111
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
112
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
113
+ embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
114
+ else:
115
+ # Simple mean pooling without attention mask
116
+ embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
117
+ else:
118
+ # Fallback to pooled output if available
119
+ embedding = outputs.pooler_output.squeeze().cpu().numpy()
120
+
121
  embeddings.append(embedding.tolist())
122
 
123
  elif model and hasattr(model, 'encode'):
 
130
  except Exception as e:
131
  logger.warning(f"Error generating embedding for text: {str(e)}")
132
  # Return zero vector as last resort
133
+ embeddings.append([0.0] * 1024) # Qwen3-Embedding-0.6B has 1024 dimensions
134
 
135
  return embeddings[0] if single_text else embeddings
136
 
 
138
  logger.error(f"Error in generate_embeddings: {str(e)}")
139
  # Return zero vectors as fallback
140
  if single_text:
141
+ return [0.0] * 1024
142
  else:
143
+ return [[0.0] * 1024] * len(texts)
144
 
145
  def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
146
  """Compute cosine similarity between two embeddings"""
 
217
 
218
def health_check():
    """Health check endpoint.

    Returns a dict describing service health: load status, configured model
    name/device/max length, and — when a model is loaded — its detected type
    and embedding dimension. Reads module globals (model, MODEL_NAME, DEVICE,
    MAX_LENGTH) set elsewhere in this file.
    """
    loaded = model is not None
    info = {
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "model_name": MODEL_NAME,
        "device": DEVICE,
        "max_length": MAX_LENGTH,
    }

    if loaded:
        # Distinguish the transformer-backed model (has a .config) from the
        # sentence-transformers fallback (exposes .encode); anything else is
        # reported as Unknown.
        if hasattr(model, 'config'):
            info["model_type"] = "Qwen3-Embedding"
            info["embedding_dimension"] = getattr(model.config, 'hidden_size', 1024)
        elif hasattr(model, 'encode'):
            info["model_type"] = "SentenceTransformer-Fallback"
            info["embedding_dimension"] = 384
        else:
            info["model_type"] = "Unknown"
            info["embedding_dimension"] = "Unknown"

    return info
240
 
241
  # Create FastAPI application
242
  app = FastAPI(