from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Union
import json
import logging
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Model configuration
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512
# Global variables for model and tokenizer
model = None
tokenizer = None
def load_model():
    """Load the Qwen3 embedding model and tokenizer."""
    global model, tokenizer
    try:
        logger.info(f"Loading Qwen3-Embedding-0.6B model on device: {DEVICE}")
        # Load the config first to understand the model structure
        config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
        logger.info(f"Model config loaded: {config.model_type}")
        # Load the tokenizer, falling back if remote code is unavailable
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        except Exception as tokenizer_error:
            logger.warning(f"Failed to load tokenizer with trust_remote_code=True: {tokenizer_error}")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)
        # Load the model; use fp16 and automatic device placement on GPU
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()
        # Smoke-test the model with a simple input
        test_input = tokenizer(
            "test", return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH
        ).to(DEVICE)
        with torch.no_grad():
            test_output = model(**test_input)
        logger.info(f"Model test successful. Output shape: {test_output.last_hidden_state.shape}")
        logger.info(f"Model config hidden size: {model.config.hidden_size}")
        logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}")
        logger.info("Qwen3-Embedding-0.6B model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading Qwen3 model: {str(e)}")
        logger.error("No fallback available - Qwen3 model is required")
        return False
def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using the Qwen3-Embedding-0.6B model."""
    global model, tokenizer
    if model is None or tokenizer is None:
        raise Exception("Qwen3 model not loaded. Please ensure the model is properly loaded.")
    try:
        # Normalize the input to a list, remembering whether a single string was passed
        if isinstance(texts, str):
            texts = [texts]
            single_text = True
        else:
            single_text = False
        embeddings = []
        for text in texts:
            try:
                # Tokenize; truncation to MAX_LENGTH tokens is handled here by the
                # tokenizer, so no character-level pre-truncation is needed
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=MAX_LENGTH,
                ).to(DEVICE)
                with torch.no_grad():
                    outputs = model(**inputs)
                # For Qwen3 embedding models, mean-pool the last hidden state
                if hasattr(outputs, 'last_hidden_state'):
                    attention_mask = inputs.get('attention_mask', None)
                    if attention_mask is not None:
                        # Masked mean pooling: exclude padding positions from the average
                        token_embeddings = outputs.last_hidden_state
                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
                    else:
                        # Simple mean pooling without an attention mask
                        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
                else:
                    # Fall back to the pooled output if available
                    embedding = outputs.pooler_output.squeeze().cpu().numpy()
                embeddings.append(embedding.tolist())
            except Exception as e:
                logger.error(f"Error generating embedding for text: {str(e)}")
                raise Exception(f"Failed to generate embedding: {str(e)}")
        return embeddings[0] if single_text else embeddings
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        raise Exception(f"Embedding generation failed: {str(e)}")
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Compute cosine similarity between two embeddings."""
    try:
        emb1 = np.array(embedding1)
        emb2 = np.array(embedding2)
        # Cosine similarity: dot product divided by the product of the norms
        dot_product = np.dot(emb1, emb2)
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        similarity = dot_product / (norm1 * norm2)
        return float(similarity)
    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0
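# Quick sanity checks (pure vector math, no model required):
#   compute_similarity([1.0, 0.0], [0.0, 1.0])  -> 0.0  (orthogonal)
#   compute_similarity([1.0, 0.0], [2.0, 0.0])  -> 1.0  (same direction, scale-invariant)
#   compute_similarity([1.0, 0.0], [-1.0, 0.0]) -> -1.0 (opposite direction)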
def batch_embedding_interface(texts: str) -> str:
    """Interface for batch embedding generation (newline-separated input)."""
    try:
        # Split input into one text per non-empty line
        text_list = [text.strip() for text in texts.split('\n') if text.strip()]
        if not text_list:
            return json.dumps([])
        embeddings = generate_embeddings(text_list)
        return json.dumps(embeddings)
    except Exception as e:
        logger.error(f"Error in batch_embedding_interface: {str(e)}")
        return json.dumps([])
def single_embedding_interface(text: str) -> str:
    """Interface for single embedding generation."""
    try:
        if not text.strip():
            return json.dumps([])
        embedding = generate_embeddings(text)
        return json.dumps(embedding)
    except Exception as e:
        logger.error(f"Error in single_embedding_interface: {str(e)}")
        return json.dumps([])
def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Interface for computing similarity between two JSON-encoded embeddings."""
    try:
        emb1 = json.loads(embedding1)
        emb2 = json.loads(embedding2)
        return compute_similarity(emb1, emb2)
    except Exception as e:
        logger.error(f"Error in similarity_interface: {str(e)}")
        return 0.0
def health_check():
    """Report model/tokenizer status and basic configuration."""
    model_info = {
        "status": "healthy" if model is not None and tokenizer is not None else "unhealthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_name": MODEL_NAME,
        "device": DEVICE,
        "max_length": MAX_LENGTH,
    }
    if model is not None and tokenizer is not None:
        if hasattr(model, 'config'):
            model_info["model_type"] = "Qwen3-Embedding"
            model_info["embedding_dimension"] = getattr(model.config, 'hidden_size', 1024)
            model_info["tokenizer_loaded"] = True
        else:
            model_info["model_type"] = "Unknown"
            model_info["embedding_dimension"] = "Unknown"
            model_info["tokenizer_loaded"] = False
    else:
        model_info["model_type"] = "Not Loaded"
        model_info["embedding_dimension"] = "N/A"
        model_info["tokenizer_loaded"] = tokenizer is not None
    return model_info
# Create the FastAPI application
app = FastAPI(
    title="Qwen3 Embedding API",
    description="A stable API for generating text embeddings using the Qwen3-Embedding-0.6B model",
    version="1.0.0",
)
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Qwen3 Embedding API",
        "version": "1.0.0",
        "model": "Qwen3-Embedding-0.6B",
        "endpoints": {
            "health": "/health",
            "predict": "/api/predict",
            "docs": "/docs",
        },
    }
@app.get("/health")
async def health():
    """Health check endpoint"""
    return health_check()
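# Example (hypothetical local deployment; main() binds 0.0.0.0:7860):
#   curl http://localhost:7860/health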
@app.post("/api/predict")
async def predict(data: dict):
"""Main prediction endpoint for embeddings"""
try:
# Check for new format first (texts parameter)
if "texts" in data:
texts = data["texts"]
normalize = data.get("normalize", True)
if not isinstance(texts, list):
raise HTTPException(status_code=400, detail="'texts' must be a list")
if len(texts) == 0:
raise HTTPException(status_code=400, detail="'texts' list cannot be empty")
# Generate embeddings
logger.info(f"Generating embeddings for {len(texts)} texts")
embeddings = generate_embeddings(texts)
logger.info(f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0]) if embeddings else 0}")
# Normalize embeddings if requested
if normalize:
import numpy as np
try:
embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
logger.info("Embeddings normalized")
except Exception as norm_error:
logger.warning(f"Normalization failed: {str(norm_error)}, returning unnormalized embeddings")
# Continue with unnormalized embeddings
return {
"embeddings": embeddings,
"model": MODEL_NAME,
"usage": {
"prompt_tokens": sum(len(text.split()) for text in texts),
"total_tokens": sum(len(text.split()) for text in texts)
}
}
# Fallback to old format for backward compatibility
elif "data" in data:
input_data = data["data"]
# Handle single text or batch texts
if isinstance(input_data, str):
# Single text
embeddings = generate_embeddings(input_data)
return {"data": [embeddings]}
elif isinstance(input_data, list):
if len(input_data) > 0 and isinstance(input_data[0], str):
# Single text in list
embeddings = generate_embeddings(input_data[0])
return {"data": [embeddings]}
elif len(input_data) > 0 and isinstance(input_data[0], list):
# Batch texts
embeddings = generate_embeddings(input_data[0])
return {"data": [embeddings]}
else:
raise HTTPException(status_code=400, detail="Invalid data format")
else:
raise HTTPException(status_code=400, detail="Invalid data type")
else:
raise HTTPException(status_code=400, detail="Missing 'texts' or 'data' field in request")
except Exception as e:
logger.error(f"Error in predict endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
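# Example request (hypothetical local deployment; payload fields match the handler above):
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello world", "goodbye"], "normalize": true}'
# Response shape: {"embeddings": [[...], [...]], "model": "Qwen/Qwen3-Embedding-0.6B", "usage": {...}}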
@app.post("/api/similarity")
async def similarity(data: dict):
"""Compute similarity between two texts or embeddings"""
try:
# Check for new format first (text1, text2 parameters)
if "text1" in data and "text2" in data:
text1 = data["text1"]
text2 = data["text2"]
if not isinstance(text1, str) or not isinstance(text2, str):
raise HTTPException(status_code=400, detail="text1 and text2 must be strings")
# Generate embeddings for both texts
emb1 = generate_embeddings(text1)
emb2 = generate_embeddings(text2)
# Compute similarity
sim = compute_similarity(emb1, emb2)
return {
"similarity": sim,
"model": MODEL_NAME,
"text1": text1,
"text2": text2
}
# Fallback to old format (embedding1, embedding2 parameters)
elif "embedding1" in data and "embedding2" in data:
emb1 = data["embedding1"]
emb2 = data["embedding2"]
if not isinstance(emb1, list) or not isinstance(emb2, list):
raise HTTPException(status_code=400, detail="Embeddings must be lists")
sim = compute_similarity(emb1, emb2)
return {"similarity": sim}
else:
raise HTTPException(status_code=400, detail="Missing 'text1' and 'text2' or 'embedding1' and 'embedding2' fields")
except Exception as e:
logger.error(f"Error in similarity endpoint: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
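# Example request (hypothetical local deployment):
#   curl -X POST http://localhost:7860/api/similarity \
#        -H "Content-Type: application/json" \
#        -d '{"text1": "a cat", "text2": "a kitten"}'
# Response shape: {"similarity": 0.87, "model": "...", "text1": "...", "text2": "..."}
# (the similarity value shown is illustrative, not a measured result)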
def main():
    """Load the model, then start the API server."""
    logger.info("Starting Qwen3 Embedding Model API...")
    if not load_model():
        logger.error("Failed to load model. Exiting...")
        return
    logger.info("Model loaded successfully. Starting FastAPI server...")
    # Run with uvicorn on port 7860, the default port for Hugging Face Spaces
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )
if __name__ == "__main__":
    main()
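# Client sketch (hypothetical; assumes the `requests` package and a running server):
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/predict",
#       json={"texts": ["hello world"], "normalize": True},
#   )
#   vectors = resp.json()["embeddings"]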