from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Union
import json
import logging
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration - Qwen3 embedding model
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # maximum sequence length in tokens

# Global variables for the model and tokenizer
model = None
tokenizer = None
def load_model():
    """Load the Qwen3 embedding model and tokenizer."""
    global model, tokenizer
    try:
        logger.info(f"Loading Qwen3-Embedding-0.6B model on device: {DEVICE}")

        # Load the config first to confirm the model structure
        config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
        logger.info(f"Model config loaded: {config.model_type}")

        # Load the tokenizer, falling back to trust_remote_code=False if needed
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        except Exception as tokenizer_error:
            logger.warning(f"Failed to load tokenizer with trust_remote_code=True: {tokenizer_error}")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)

        # Load the model; use fp16 and device_map="auto" on GPU
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None
        )
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()

        # Smoke-test the model with a trivial input
        test_input = tokenizer("test", return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
        with torch.no_grad():
            test_output = model(**test_input)
        logger.info(f"Model test successful. Output shape: {test_output.last_hidden_state.shape}")
        logger.info(f"Model config hidden size: {model.config.hidden_size}")
        logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}")
        logger.info("Qwen3-Embedding-0.6B model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading Qwen3 model: {str(e)}")
        logger.error("No fallback available - the Qwen3 model is required")
        return False
def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using the Qwen3-Embedding-0.6B model."""
    if model is None or tokenizer is None:
        raise Exception("Qwen3 model not loaded. Please ensure the model is properly loaded.")
    try:
        # Normalize the input to a list, remembering whether a single string was passed
        if isinstance(texts, str):
            texts = [texts]
            single_text = True
        else:
            single_text = False

        embeddings = []
        for text in texts:
            try:
                # Tokenize; truncation here is token-based (max_length=MAX_LENGTH),
                # so no separate character-level truncation is needed
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=MAX_LENGTH
                ).to(DEVICE)
                with torch.no_grad():
                    outputs = model(**inputs)
                # For Qwen3 embedding models, mean-pool the last hidden state
                if hasattr(outputs, 'last_hidden_state'):
                    attention_mask = inputs.get('attention_mask', None)
                    if attention_mask is not None:
                        # Masked mean pooling: average only over non-padding tokens
                        token_embeddings = outputs.last_hidden_state
                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
                    else:
                        # Simple mean pooling without an attention mask
                        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
                else:
                    # Fall back to the pooled output if the model provides one
                    embedding = outputs.pooler_output.squeeze().cpu().numpy()
                embeddings.append(embedding.tolist())
            except Exception as e:
                logger.error(f"Error generating embedding for text: {str(e)}")
                raise Exception(f"Failed to generate embedding: {str(e)}")
        return embeddings[0] if single_text else embeddings
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        raise Exception(f"Embedding generation failed: {str(e)}")
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Compute the cosine similarity between two embeddings."""
    try:
        # Convert to numpy arrays
        emb1 = np.array(embedding1)
        emb2 = np.array(embedding2)
        # Cosine similarity: dot product divided by the product of the norms
        dot_product = np.dot(emb1, emb2)
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        similarity = dot_product / (norm1 * norm2)
        return float(similarity)
    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0
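# Quick sanity check (illustrative only): identical vectors score 1.0 and
# orthogonal vectors score 0.0, e.g.
#   compute_similarity([1.0, 0.0], [1.0, 0.0])  # -> 1.0
#   compute_similarity([1.0, 0.0], [0.0, 1.0])  # -> 0.0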
# Plain-string helper interfaces (e.g. for a simple UI front end); these are
# not wired to the FastAPI routes defined below.
def batch_embedding_interface(texts: str) -> str:
    """Interface for batch embedding generation: one input text per line."""
    try:
        # Split the input on newlines, dropping blank lines
        text_list = [text.strip() for text in texts.split('\n') if text.strip()]
        if not text_list:
            return json.dumps([])
        # Generate embeddings and return them as a JSON string
        embeddings = generate_embeddings(text_list)
        return json.dumps(embeddings)
    except Exception as e:
        logger.error(f"Error in batch_embedding_interface: {str(e)}")
        return json.dumps([])

def single_embedding_interface(text: str) -> str:
    """Interface for single embedding generation."""
    try:
        if not text.strip():
            return json.dumps([])
        # Generate the embedding and return it as a JSON string
        embedding = generate_embeddings(text)
        return json.dumps(embedding)
    except Exception as e:
        logger.error(f"Error in single_embedding_interface: {str(e)}")
        return json.dumps([])

def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Interface for computing similarity between two JSON-encoded embeddings."""
    try:
        # Parse the embeddings from their JSON strings
        emb1 = json.loads(embedding1)
        emb2 = json.loads(embedding2)
        return compute_similarity(emb1, emb2)
    except Exception as e:
        logger.error(f"Error in similarity_interface: {str(e)}")
        return 0.0
def health_check():
    """Report model/tokenizer status and basic configuration."""
    model_info = {
        "status": "healthy" if model is not None and tokenizer is not None else "unhealthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_name": MODEL_NAME,
        "device": DEVICE,
        "max_length": MAX_LENGTH
    }
    if model is not None and tokenizer is not None:
        if hasattr(model, 'config'):
            model_info["model_type"] = "Qwen3-Embedding"
            model_info["embedding_dimension"] = getattr(model.config, 'hidden_size', 1024)
            model_info["tokenizer_loaded"] = True
        else:
            model_info["model_type"] = "Unknown"
            model_info["embedding_dimension"] = "Unknown"
            model_info["tokenizer_loaded"] = False
    else:
        model_info["model_type"] = "Not Loaded"
        model_info["embedding_dimension"] = "N/A"
        model_info["tokenizer_loaded"] = tokenizer is not None
    return model_info
# Create the FastAPI application
app = FastAPI(
    title="Qwen3 Embedding API",
    description="A stable API for generating text embeddings using the Qwen3-Embedding-0.6B model",
    version="1.0.0"
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
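# Note: allow_origins=["*"] combined with allow_credentials=True is maximally
# permissive; restrict the origin list before exposing this service publicly.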
# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Qwen3 Embedding API",
        "version": "1.0.0",
        "model": "Qwen3-Embedding-0.6B",
        "endpoints": {
            "health": "/health",
            "predict": "/api/predict",
            "docs": "/docs"
        }
    }
@app.get("/health")
async def health():
    """Health check endpoint"""
    return health_check()
@app.post("/api/predict")
async def predict(data: dict):
    """Main prediction endpoint for embeddings"""
    try:
        # Check for the new format first ("texts" parameter)
        if "texts" in data:
            texts = data["texts"]
            normalize = data.get("normalize", True)
            if not isinstance(texts, list):
                raise HTTPException(status_code=400, detail="'texts' must be a list")
            if len(texts) == 0:
                raise HTTPException(status_code=400, detail="'texts' list cannot be empty")
            # Generate embeddings
            logger.info(f"Generating embeddings for {len(texts)} texts")
            embeddings = generate_embeddings(texts)
            logger.info(f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0]) if embeddings else 0}")
            # L2-normalize the embeddings if requested
            if normalize:
                try:
                    embeddings = [(np.asarray(emb) / np.linalg.norm(emb)).tolist() for emb in embeddings]
                    logger.info("Embeddings normalized")
                except Exception as norm_error:
                    logger.warning(f"Normalization failed: {str(norm_error)}, returning unnormalized embeddings")
            return {
                "embeddings": embeddings,
                "model": MODEL_NAME,
                "usage": {
                    # Approximate counts: whitespace-separated words, not model tokens
                    "prompt_tokens": sum(len(text.split()) for text in texts),
                    "total_tokens": sum(len(text.split()) for text in texts)
                }
            }
        # Fall back to the old format for backward compatibility
        elif "data" in data:
            input_data = data["data"]
            # Handle a single text or a batch of texts
            if isinstance(input_data, str):
                # Single text
                embeddings = generate_embeddings(input_data)
                return {"data": [embeddings]}
            elif isinstance(input_data, list):
                if len(input_data) > 0 and isinstance(input_data[0], str):
                    # Single text in a list
                    embeddings = generate_embeddings(input_data[0])
                    return {"data": [embeddings]}
                elif len(input_data) > 0 and isinstance(input_data[0], list):
                    # Batch of texts
                    embeddings = generate_embeddings(input_data[0])
                    return {"data": [embeddings]}
                else:
                    raise HTTPException(status_code=400, detail="Invalid data format")
            else:
                raise HTTPException(status_code=400, detail="Invalid data type")
        else:
            raise HTTPException(status_code=400, detail="Missing 'texts' or 'data' field in request")
    except HTTPException:
        # Re-raise HTTP errors unchanged instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Error in predict endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# Route path assumed here; the original source did not specify one
@app.post("/api/similarity")
async def similarity(data: dict):
    """Compute the similarity between two texts or two embeddings"""
    try:
        # Check for the new format first ("text1"/"text2" parameters)
        if "text1" in data and "text2" in data:
            text1 = data["text1"]
            text2 = data["text2"]
            if not isinstance(text1, str) or not isinstance(text2, str):
                raise HTTPException(status_code=400, detail="text1 and text2 must be strings")
            # Generate embeddings for both texts and compare them
            emb1 = generate_embeddings(text1)
            emb2 = generate_embeddings(text2)
            sim = compute_similarity(emb1, emb2)
            return {
                "similarity": sim,
                "model": MODEL_NAME,
                "text1": text1,
                "text2": text2
            }
        # Fall back to the old format ("embedding1"/"embedding2" parameters)
        elif "embedding1" in data and "embedding2" in data:
            emb1 = data["embedding1"]
            emb2 = data["embedding2"]
            if not isinstance(emb1, list) or not isinstance(emb2, list):
                raise HTTPException(status_code=400, detail="Embeddings must be lists")
            sim = compute_similarity(emb1, emb2)
            return {"similarity": sim}
        else:
            raise HTTPException(status_code=400, detail="Missing 'text1' and 'text2' or 'embedding1' and 'embedding2' fields")
    except HTTPException:
        # Re-raise HTTP errors unchanged instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Error in similarity endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
def main():
    """Load the model and run the application."""
    logger.info("Starting Qwen3 Embedding Model API...")
    # Load the model before serving any requests
    if not load_model():
        logger.error("Failed to load model. Exiting...")
        return
    logger.info("Model loaded successfully. Starting FastAPI server...")
    # Run with uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )

if __name__ == "__main__":
    main()
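# Example requests (illustrative; assumes the server is running on localhost:7860,
# and that the /api/similarity route matches the path assumed above):
#
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"texts": ["hello world"], "normalize": true}'
#
#   curl -X POST http://localhost:7860/api/similarity \
#        -H "Content-Type: application/json" \
#        -d '{"text1": "a cat", "text2": "a kitten"}'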