from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Union
import json
import logging
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration - Qwen3 embedding model
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512

# Global model and tokenizer, populated by load_model()
model = None
tokenizer = None


def load_model():
    """Load the Qwen3 embedding model and tokenizer."""
    global model, tokenizer
    try:
        logger.info(f"Loading Qwen3-Embedding-0.6B model on device: {DEVICE}")

        # Load the config first to confirm the model structure
        config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
        logger.info(f"Model config loaded: {config.model_type}")

        # Load the tokenizer, falling back to trust_remote_code=False if needed
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        except Exception as tokenizer_error:
            logger.warning(f"Failed to load tokenizer with trust_remote_code=True: {tokenizer_error}")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)

        # Load the model: fp16 with device_map="auto" on GPU, fp32 on CPU
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()

        # Smoke-test the model with a trivial input
        test_input = tokenizer(
            "test",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(DEVICE)
        with torch.no_grad():
            test_output = model(**test_input)
        logger.info(f"Model test successful. Output shape: {test_output.last_hidden_state.shape}")
        logger.info(f"Model config hidden size: {model.config.hidden_size}")
        logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}")

        logger.info("Qwen3-Embedding-0.6B model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading Qwen3 model: {str(e)}")
        logger.error("No fallback available - Qwen3 model is required")
        return False
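
# Illustration of the masked mean pooling used by generate_embeddings() below,
# on a toy batch (a minimal sketch; shapes are [batch, seq_len, hidden_size]):
#
#   token_embeddings = torch.ones(1, 4, 8)           # 4 tokens, hidden size 8
#   attention_mask = torch.tensor([[1, 1, 1, 0]])    # last position is padding
#   mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
#   pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
#   # pooled has shape [1, 8]; padding positions do not contribute to the mean
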
def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using the Qwen3-Embedding-0.6B model."""
    if model is None or tokenizer is None:
        raise RuntimeError("Qwen3 model not loaded. Please ensure the model is properly loaded.")

    try:
        # Normalize the input to a list, remembering whether a single string was passed
        if isinstance(texts, str):
            texts = [texts]
            single_text = True
        else:
            single_text = False

        # Note: length is enforced at tokenization time (truncation=True,
        # max_length=MAX_LENGTH); slicing text[:MAX_LENGTH] here would cut by
        # characters rather than tokens and silently drop content.

        embeddings = []
        for text in texts:
            try:
                # Tokenize, truncating to the MAX_LENGTH token limit
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=MAX_LENGTH,
                ).to(DEVICE)

                with torch.no_grad():
                    outputs = model(**inputs)

                # For Qwen3 embedding models, mean-pool the last hidden state
                if hasattr(outputs, "last_hidden_state"):
                    attention_mask = inputs.get("attention_mask", None)
                    if attention_mask is not None:
                        # Masked mean pooling: padding tokens are excluded from the average
                        token_embeddings = outputs.last_hidden_state
                        input_mask_expanded = (
                            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                        )
                        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
                    else:
                        # Simple mean pooling without an attention mask
                        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
                else:
                    # Fall back to the pooled output if the model provides one
                    embedding = outputs.pooler_output.squeeze().cpu().numpy()

                embeddings.append(embedding.tolist())
            except Exception as e:
                logger.error(f"Error generating embedding for text: {str(e)}")
                raise RuntimeError(f"Failed to generate embedding: {str(e)}")

        return embeddings[0] if single_text else embeddings
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        raise RuntimeError(f"Embedding generation failed: {str(e)}")


def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Compute cosine similarity between two embeddings."""
    try:
        emb1 = np.array(embedding1)
        emb2 = np.array(embedding2)

        # Cosine similarity: dot(a, b) / (||a|| * ||b||)
        dot_product = np.dot(emb1, emb2)
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        return float(dot_product / (norm1 * norm2))
    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0


def batch_embedding_interface(texts: str) -> str:
    """Interface for batch embedding generation (newline-separated input)."""
    try:
        # One text per non-empty line
        text_list = [text.strip() for text in texts.split("\n") if text.strip()]
        if not text_list:
            return json.dumps([])
        embeddings = generate_embeddings(text_list)
        return json.dumps(embeddings)
    except Exception as e:
        logger.error(f"Error in batch_embedding_interface: {str(e)}")
        return json.dumps([])


def single_embedding_interface(text: str) -> str:
    """Interface for single embedding generation."""
    try:
        if not text.strip():
            return json.dumps([])
        embedding = generate_embeddings(text)
        return json.dumps(embedding)
    except Exception as e:
        logger.error(f"Error in single_embedding_interface: {str(e)}")
        return json.dumps([])


def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Interface for computing similarity between two JSON-encoded embeddings."""
    try:
        emb1 = json.loads(embedding1)
        emb2 = json.loads(embedding2)
        return compute_similarity(emb1, emb2)
    except Exception as e:
        logger.error(f"Error in similarity_interface: {str(e)}")
        return 0.0
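
# Toy sanity checks for compute_similarity() above (exact values follow from
# the cosine formula and are independent of the model):
#
#   compute_similarity([1.0, 0.0], [1.0, 0.0])   # -> 1.0  (identical direction)
#   compute_similarity([1.0, 0.0], [0.0, 1.0])   # -> 0.0  (orthogonal)
#   compute_similarity([1.0, 0.0], [-1.0, 0.0])  # -> -1.0 (opposite direction)
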
def health_check():
    """Build the health check payload."""
    loaded = model is not None and tokenizer is not None
    model_info = {
        "status": "healthy" if loaded else "unhealthy",
        "model_loaded": loaded,
        "model_name": MODEL_NAME,
        "device": DEVICE,
        "max_length": MAX_LENGTH,
    }
    if loaded:
        if hasattr(model, "config"):
            model_info["model_type"] = "Qwen3-Embedding"
            model_info["embedding_dimension"] = getattr(model.config, "hidden_size", 1024)
            model_info["tokenizer_loaded"] = True
        else:
            model_info["model_type"] = "Unknown"
            model_info["embedding_dimension"] = "Unknown"
            model_info["tokenizer_loaded"] = False
    else:
        model_info["model_type"] = "Not Loaded"
        model_info["embedding_dimension"] = "N/A"
        model_info["tokenizer_loaded"] = tokenizer is not None
    return model_info


# Create the FastAPI application
app = FastAPI(
    title="Qwen3 Embedding API",
    description="A stable API for generating text embeddings using the Qwen3-Embedding-0.6B model",
    version="1.0.0",
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint with API information."""
    return {
        "message": "Qwen3 Embedding API",
        "version": "1.0.0",
        "model": "Qwen3-Embedding-0.6B",
        "endpoints": {
            "health": "/health",
            "predict": "/api/predict",
            "similarity": "/api/similarity",
            "docs": "/docs",
        },
    }


@app.get("/health")
async def health():
    """Health check endpoint."""
    return health_check()
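
# Example request/response for the /api/predict handler below. The embedding
# values are placeholders; the dimension comes from the model's hidden size
# (1024 is the default this code assumes for Qwen3-Embedding-0.6B), and the
# usage counts are whitespace word counts as computed in the handler:
#
#   POST /api/predict
#   {"texts": ["hello world", "goodbye"], "normalize": true}
#
#   -> {"embeddings": [[...1024 floats...], [...1024 floats...]],
#       "model": "Qwen/Qwen3-Embedding-0.6B",
#       "usage": {"prompt_tokens": 3, "total_tokens": 3}}
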
@app.post("/api/predict")
async def predict(data: dict):
    """Main prediction endpoint for embeddings."""
    try:
        # New format: {"texts": [...], "normalize": bool}
        if "texts" in data:
            texts = data["texts"]
            normalize = data.get("normalize", True)

            if not isinstance(texts, list):
                raise HTTPException(status_code=400, detail="'texts' must be a list")
            if len(texts) == 0:
                raise HTTPException(status_code=400, detail="'texts' list cannot be empty")
            if not all(isinstance(text, str) for text in texts):
                raise HTTPException(status_code=400, detail="'texts' must be a list of strings")

            logger.info(f"Generating embeddings for {len(texts)} texts")
            embeddings = generate_embeddings(texts)
            logger.info(
                f"Generated {len(embeddings)} embeddings with dimension "
                f"{len(embeddings[0]) if embeddings else 0}"
            )

            # L2-normalize if requested; fall back to unnormalized embeddings on failure.
            # Convert back to plain lists so the response is JSON-serializable.
            if normalize:
                try:
                    embeddings = [
                        (np.asarray(emb) / max(np.linalg.norm(emb), 1e-12)).tolist()
                        for emb in embeddings
                    ]
                    logger.info("Embeddings normalized")
                except Exception as norm_error:
                    logger.warning(
                        f"Normalization failed: {str(norm_error)}, "
                        "returning unnormalized embeddings"
                    )

            return {
                "embeddings": embeddings,
                "model": MODEL_NAME,
                "usage": {
                    # Approximate token counts via whitespace splitting
                    "prompt_tokens": sum(len(text.split()) for text in texts),
                    "total_tokens": sum(len(text.split()) for text in texts),
                },
            }

        # Old format kept for backward compatibility: {"data": ...}
        elif "data" in data:
            input_data = data["data"]
            if isinstance(input_data, str):
                # Single text
                embeddings = generate_embeddings(input_data)
                return {"data": [embeddings]}
            elif isinstance(input_data, list):
                if len(input_data) > 0 and isinstance(input_data[0], str):
                    # Single text wrapped in a list
                    embeddings = generate_embeddings(input_data[0])
                    return {"data": [embeddings]}
                elif len(input_data) > 0 and isinstance(input_data[0], list):
                    # Batch of texts wrapped in a list
                    embeddings = generate_embeddings(input_data[0])
                    return {"data": [embeddings]}
                else:
                    raise HTTPException(status_code=400, detail="Invalid data format")
            else:
                raise HTTPException(status_code=400, detail="Invalid data type")
        else:
            raise HTTPException(status_code=400, detail="Missing 'texts' or 'data' field in request")
    except HTTPException:
        # Let deliberate 4xx errors through instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error in predict endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@app.post("/api/similarity")
async def similarity(data: dict):
    """Compute similarity between two texts or two precomputed embeddings."""
    try:
        # New format: {"text1": ..., "text2": ...}
        if "text1" in data and "text2" in data:
            text1 = data["text1"]
            text2 = data["text2"]
            if not isinstance(text1, str) or not isinstance(text2, str):
                raise HTTPException(status_code=400, detail="text1 and text2 must be strings")

            # Embed both texts, then compare
            emb1 = generate_embeddings(text1)
            emb2 = generate_embeddings(text2)
            sim = compute_similarity(emb1, emb2)
            return {
                "similarity": sim,
                "model": MODEL_NAME,
                "text1": text1,
                "text2": text2,
            }

        # Old format: {"embedding1": [...], "embedding2": [...]}
        elif "embedding1" in data and "embedding2" in data:
            emb1 = data["embedding1"]
            emb2 = data["embedding2"]
            if not isinstance(emb1, list) or not isinstance(emb2, list):
                raise HTTPException(status_code=400, detail="Embeddings must be lists")
            sim = compute_similarity(emb1, emb2)
            return {"similarity": sim}
        else:
            raise HTTPException(
                status_code=400,
                detail="Missing 'text1' and 'text2' or 'embedding1' and 'embedding2' fields",
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in similarity endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


def main():
    """Load the model and run the API server."""
    logger.info("Starting Qwen3 Embedding Model API...")

    if not load_model():
        logger.error("Failed to load model. Exiting...")
        return

    logger.info("Model loaded successfully. Starting FastAPI server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")


if __name__ == "__main__":
    main()
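
# Example usage once the server is running (host/port from main() above):
#
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello world"]}'
#   curl -X POST http://localhost:7860/api/similarity \
#        -H "Content-Type: application/json" \
#        -d '{"text1": "cats", "text2": "dogs"}'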