Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,7 +28,7 @@ def load_model():
|
|
| 28 |
global model, tokenizer
|
| 29 |
|
| 30 |
try:
|
| 31 |
-
logger.info(f"Loading Qwen3
|
| 32 |
|
| 33 |
# Load tokenizer and model for Qwen3 embedding
|
| 34 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
@@ -44,7 +44,13 @@ def load_model():
|
|
| 44 |
|
| 45 |
model.eval()
|
| 46 |
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
return True
|
| 49 |
|
| 50 |
except Exception as e:
|
|
@@ -62,7 +68,7 @@ def load_model():
|
|
| 62 |
return False
|
| 63 |
|
| 64 |
def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
|
| 65 |
-
"""Generate embeddings for input text(s) using Qwen3
|
| 66 |
global model, tokenizer
|
| 67 |
|
| 68 |
try:
|
|
@@ -80,8 +86,9 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
|
|
| 80 |
|
| 81 |
for text in texts:
|
| 82 |
try:
|
| 83 |
-
# Method 1: Try using the
|
| 84 |
-
if model and tokenizer:
|
|
|
|
| 85 |
inputs = tokenizer(
|
| 86 |
text,
|
| 87 |
return_tensors="pt",
|
|
@@ -92,8 +99,25 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
|
|
| 92 |
|
| 93 |
with torch.no_grad():
|
| 94 |
outputs = model(**inputs)
|
| 95 |
-
|
| 96 |
-
embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
embeddings.append(embedding.tolist())
|
| 98 |
|
| 99 |
elif model and hasattr(model, 'encode'):
|
|
@@ -106,7 +130,7 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
|
|
| 106 |
except Exception as e:
|
| 107 |
logger.warning(f"Error generating embedding for text: {str(e)}")
|
| 108 |
# Return zero vector as last resort
|
| 109 |
-
embeddings.append([0.0] *
|
| 110 |
|
| 111 |
return embeddings[0] if single_text else embeddings
|
| 112 |
|
|
@@ -114,9 +138,9 @@ def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List
|
|
| 114 |
logger.error(f"Error in generate_embeddings: {str(e)}")
|
| 115 |
# Return zero vectors as fallback
|
| 116 |
if single_text:
|
| 117 |
-
return [0.0] *
|
| 118 |
else:
|
| 119 |
-
return [[0.0] *
|
| 120 |
|
| 121 |
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
|
| 122 |
"""Compute cosine similarity between two embeddings"""
|
|
@@ -193,7 +217,26 @@ def similarity_interface(embedding1: str, embedding2: str) -> float:
|
|
| 193 |
|
| 194 |
def health_check():
|
| 195 |
"""Health check endpoint"""
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
# Create FastAPI application
|
| 199 |
app = FastAPI(
|
|
|
|
| 28 |
global model, tokenizer
|
| 29 |
|
| 30 |
try:
|
| 31 |
+
logger.info(f"Loading Qwen3-Embedding-0.6B model on device: {DEVICE}")
|
| 32 |
|
| 33 |
# Load tokenizer and model for Qwen3 embedding
|
| 34 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
|
|
| 44 |
|
| 45 |
model.eval()
|
| 46 |
|
| 47 |
+
# Test the model with a simple input
|
| 48 |
+
test_input = tokenizer("test", return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
test_output = model(**test_input)
|
| 51 |
+
logger.info(f"Model test successful. Output shape: {test_output.last_hidden_state.shape}")
|
| 52 |
+
|
| 53 |
+
logger.info("Qwen3-Embedding-0.6B model loaded successfully")
|
| 54 |
return True
|
| 55 |
|
| 56 |
except Exception as e:
|
|
|
|
| 68 |
return False
|
| 69 |
|
| 70 |
def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
|
| 71 |
+
"""Generate embeddings for input text(s) using Qwen3-Embedding-0.6B model"""
|
| 72 |
global model, tokenizer
|
| 73 |
|
| 74 |
try:
|
|
|
|
| 86 |
|
| 87 |
for text in texts:
|
| 88 |
try:
|
| 89 |
+
# Method 1: Try using the Qwen3 embedding model directly
|
| 90 |
+
if model and tokenizer and hasattr(model, 'forward'):
|
| 91 |
+
# This is the Qwen3 embedding model
|
| 92 |
inputs = tokenizer(
|
| 93 |
text,
|
| 94 |
return_tensors="pt",
|
|
|
|
| 99 |
|
| 100 |
with torch.no_grad():
|
| 101 |
outputs = model(**inputs)
|
| 102 |
+
|
| 103 |
+
# For Qwen3 embedding models, use the last_hidden_state with mean pooling
|
| 104 |
+
if hasattr(outputs, 'last_hidden_state'):
|
| 105 |
+
# Mean pooling over the sequence length dimension
|
| 106 |
+
attention_mask = inputs.get('attention_mask', None)
|
| 107 |
+
if attention_mask is not None:
|
| 108 |
+
# Apply attention mask for proper mean pooling
|
| 109 |
+
token_embeddings = outputs.last_hidden_state
|
| 110 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
| 111 |
+
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
|
| 112 |
+
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
| 113 |
+
embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
|
| 114 |
+
else:
|
| 115 |
+
# Simple mean pooling without attention mask
|
| 116 |
+
embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
|
| 117 |
+
else:
|
| 118 |
+
# Fallback to pooled output if available
|
| 119 |
+
embedding = outputs.pooler_output.squeeze().cpu().numpy()
|
| 120 |
+
|
| 121 |
embeddings.append(embedding.tolist())
|
| 122 |
|
| 123 |
elif model and hasattr(model, 'encode'):
|
|
|
|
| 130 |
except Exception as e:
|
| 131 |
logger.warning(f"Error generating embedding for text: {str(e)}")
|
| 132 |
# Return zero vector as last resort
|
| 133 |
+
embeddings.append([0.0] * 1024) # Qwen3-Embedding-0.6B has 1024 dimensions
|
| 134 |
|
| 135 |
return embeddings[0] if single_text else embeddings
|
| 136 |
|
|
|
|
| 138 |
logger.error(f"Error in generate_embeddings: {str(e)}")
|
| 139 |
# Return zero vectors as fallback
|
| 140 |
if single_text:
|
| 141 |
+
return [0.0] * 1024
|
| 142 |
else:
|
| 143 |
+
return [[0.0] * 1024] * len(texts)
|
| 144 |
|
| 145 |
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
|
| 146 |
"""Compute cosine similarity between two embeddings"""
|
|
|
|
| 217 |
|
| 218 |
def health_check():
|
| 219 |
"""Health check endpoint"""
|
| 220 |
+
model_info = {
|
| 221 |
+
"status": "healthy" if model is not None else "unhealthy",
|
| 222 |
+
"model_loaded": model is not None,
|
| 223 |
+
"model_name": MODEL_NAME,
|
| 224 |
+
"device": DEVICE,
|
| 225 |
+
"max_length": MAX_LENGTH
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
if model is not None:
|
| 229 |
+
if hasattr(model, 'config'):
|
| 230 |
+
model_info["model_type"] = "Qwen3-Embedding"
|
| 231 |
+
model_info["embedding_dimension"] = getattr(model.config, 'hidden_size', 1024)
|
| 232 |
+
elif hasattr(model, 'encode'):
|
| 233 |
+
model_info["model_type"] = "SentenceTransformer-Fallback"
|
| 234 |
+
model_info["embedding_dimension"] = 384
|
| 235 |
+
else:
|
| 236 |
+
model_info["model_type"] = "Unknown"
|
| 237 |
+
model_info["embedding_dimension"] = "Unknown"
|
| 238 |
+
|
| 239 |
+
return model_info
|
| 240 |
|
| 241 |
# Create FastAPI application
|
| 242 |
app = FastAPI(
|