Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer, AutoModel | |
| from typing import List, Union | |
| import json | |
| import logging | |
| import os | |
| from sentence_transformers import SentenceTransformer | |
| import time | |
# Logging: INFO and above, through one module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration.
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Hugging Face Hub id of the Qwen3 embedding model
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
MAX_LENGTH = 512  # input length cap used when preparing texts for embedding

# Populated by load_model(); each stays None until loading succeeds.
model = tokenizer = sentence_transformer = None
def load_model():
    """Load the Qwen3 embedding model/tokenizer and the backup SentenceTransformer.

    Populates the module-level ``model``, ``tokenizer`` and
    ``sentence_transformer`` globals.

    Returns:
        True if the primary Qwen model and tokenizer loaded, False otherwise.
        A failure to load the backup model is logged as a warning but does
        not make this return False, because the primary model is still usable.
    """
    global model, tokenizer, sentence_transformer
    try:
        logger.info(f"Loading model on device: {DEVICE}")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            # Half precision only on GPU; CPU kernels expect float32.
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        model.eval()
    except Exception as e:
        # Don't leave a half-initialised pair behind (e.g. tokenizer set, model not).
        model = None
        tokenizer = None
        logger.exception(f"Error loading model: {str(e)}")
        return False

    # The backup is optional: its failure must not undo a successful primary load.
    try:
        sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
    except Exception as e:
        sentence_transformer = None
        logger.warning(f"Backup SentenceTransformer could not be loaded: {str(e)}")

    logger.info("Model loaded successfully")
    return True
# Output size of Qwen3-Embedding-0.6B; used for the zero-vector fallback.
_QWEN_EMBED_DIM = 1024


def _last_token_pool(last_hidden_state, attention_mask):
    """Return the hidden state of each sequence's last non-padding token.

    Qwen3-Embedding is trained with last-token pooling. This handles both
    left padding (every row ends on a real token) and right padding (the
    last real token is located from the attention mask).
    """
    if bool(attention_mask[:, -1].all()):
        return last_hidden_state[:, -1]
    last_idx = attention_mask.sum(dim=1) - 1
    rows = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    return last_hidden_state[rows, last_idx]


def _embed_one(text: str) -> List[float]:
    """Embed one text with Qwen3, falling back to the backup model, then zeros."""
    if model is not None and tokenizer is not None:
        try:
            # The tokenizer truncates to MAX_LENGTH *tokens*; the text is not
            # pre-cut by characters, which would discard most of a long input.
            inputs = tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
            ).to(DEVICE)
            with torch.no_grad():
                outputs = model(**inputs)
            pooled = _last_token_pool(outputs.last_hidden_state, inputs["attention_mask"])
            # Upcast half-precision output before converting to Python floats.
            return pooled[0].float().cpu().tolist()
        except Exception as e:
            logger.warning(f"Error generating embedding for text: {str(e)}")

    # NOTE(review): the backup model's vector size may differ from Qwen's 1024
    # (all-MiniLM-L6-v2 is believed to be 384-d — confirm), so backup vectors
    # are not comparable with primary ones.
    if sentence_transformer is not None:
        return sentence_transformer.encode(text).tolist()

    logger.warning("Error generating embedding for text: No model available")
    return [0.0] * _QWEN_EMBED_DIM


def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using the Qwen3 Embedding model.

    Each text is embedded independently. The Qwen model is used when loaded;
    on failure or absence the SentenceTransformer backup is used for that
    text, and if that is also unavailable a zero vector is returned.

    Args:
        texts: A single string or a list of strings.

    Returns:
        One embedding (list of floats) when ``texts`` is a string, otherwise a
        list of embeddings in input order. On an unexpected error, zero
        vectors of the same arity are returned and the error is logged.
    """
    single_text = isinstance(texts, str)
    try:
        text_list = [texts] if single_text else list(texts)
        embeddings = [_embed_one(text) for text in text_list]
        return embeddings[0] if single_text else embeddings
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        if single_text:
            return [0.0] * _QWEN_EMBED_DIM
        count = len(texts) if hasattr(texts, "__len__") else 0
        # Build distinct lists; [[0.0] * d] * n would alias a single list n times.
        return [[0.0] * _QWEN_EMBED_DIM for _ in range(count)]
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Compute cosine similarity between two embeddings.

    Elements are converted to float, so numeric strings such as "0.1" are
    accepted as well as numbers.

    Returns:
        The cosine similarity, clipped to [-1, 1] to absorb floating-point
        rounding. Returns 0.0 when either vector has zero norm, when the
        inputs are not 1-D vectors of equal length, or when conversion fails
        (the latter two cases are logged).
    """
    try:
        emb1 = np.asarray(embedding1, dtype=float)
        emb2 = np.asarray(embedding2, dtype=float)
        if emb1.ndim != 1 or emb1.shape != emb2.shape:
            logger.error(
                f"Error computing similarity: incompatible shapes {emb1.shape} and {emb2.shape}"
            )
            return 0.0
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        if norm1 == 0 or norm2 == 0:
            return 0.0
        similarity = np.dot(emb1, emb2) / (norm1 * norm2)
        return float(np.clip(similarity, -1.0, 1.0))
    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0
def batch_embedding_interface(texts: str) -> str:
    """Embed newline-separated texts and return the embeddings as a JSON string.

    Each line is stripped and blank lines are dropped. If nothing remains, or
    if any error occurs (which is logged), an empty JSON list is returned.
    """
    try:
        stripped = (line.strip() for line in texts.split('\n'))
        non_blank = [line for line in stripped if line]
        if not non_blank:
            return json.dumps([])
        return json.dumps(generate_embeddings(non_blank))
    except Exception as err:
        logger.error(f"Error in batch_embedding_interface: {str(err)}")
        return json.dumps([])
def single_embedding_interface(text: str) -> str:
    """Embed one text and return the vector serialized as a JSON string.

    Whitespace-only input yields an empty JSON list without touching the
    model. Any error is logged and also yields an empty JSON list.
    """
    try:
        if text.strip():
            return json.dumps(generate_embeddings(text))
        return json.dumps([])
    except Exception as exc:
        logger.error(f"Error in single_embedding_interface: {str(exc)}")
        return json.dumps([])
def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Parse two JSON-encoded embeddings and return their cosine similarity.

    Any JSON parsing or computation error is logged and yields 0.0.
    """
    try:
        # A tuple display evaluates left to right: embedding1 is parsed first.
        first, second = json.loads(embedding1), json.loads(embedding2)
        return compute_similarity(first, second)
    except Exception as exc:
        logger.error(f"Error in similarity_interface: {str(exc)}")
        return 0.0
def health_check():
    """Report liveness and whether the primary Qwen model is loaded."""
    loaded = model is not None
    return dict(status="healthy", model_loaded=loaded)
| # Create Gradio interface | |
def create_interface():
    """Create the Gradio interface for the embedding service.

    Tabs:
    - single-text embedding;
    - batch embedding (one text per line);
    - cosine similarity between two JSON embeddings;
    - API documentation with a health-check button.

    Every event is registered with an explicit ``api_name`` equal to its
    handler's name, so programmatic clients get stable endpoint names.

    Returns:
        The (not yet launched) ``gr.Blocks`` instance.
    """
    with gr.Blocks(
        title="Qwen Embedding Model",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
            margin: auto !important;
        }
        """,
    ) as interface:
        gr.Markdown("""
        # Qwen Embedding Model API
        This space provides a stable API for generating text embeddings using the Qwen model.
        The API supports both single text and batch processing.
        """)

        with gr.Tab("Single Text Embedding"):
            gr.Markdown("Generate embedding for a single text input.")
            with gr.Row():
                with gr.Column():
                    single_text_input = gr.Textbox(
                        label="Input Text",
                        placeholder="Enter text to generate embedding...",
                        lines=3,
                    )
                    single_btn = gr.Button("Generate Embedding", variant="primary")
                with gr.Column():
                    single_output = gr.Textbox(
                        label="Embedding (JSON)",
                        lines=10,
                        interactive=False,
                    )
            single_btn.click(
                single_embedding_interface,
                inputs=[single_text_input],
                outputs=[single_output],
                api_name="single_embedding_interface",
            )

        with gr.Tab("Batch Text Embedding"):
            gr.Markdown("Generate embeddings for multiple texts (one per line).")
            with gr.Row():
                with gr.Column():
                    batch_text_input = gr.Textbox(
                        label="Input Texts (one per line)",
                        placeholder="Enter multiple texts, one per line...",
                        lines=5,
                    )
                    batch_btn = gr.Button("Generate Embeddings", variant="primary")
                with gr.Column():
                    batch_output = gr.Textbox(
                        label="Embeddings (JSON)",
                        lines=10,
                        interactive=False,
                    )
            batch_btn.click(
                batch_embedding_interface,
                inputs=[batch_text_input],
                outputs=[batch_output],
                api_name="batch_embedding_interface",
            )

        with gr.Tab("Similarity Calculator"):
            gr.Markdown("Compute cosine similarity between two embeddings.")
            with gr.Row():
                with gr.Column():
                    # Placeholders show numbers: the embedding tabs emit JSON numbers, not strings.
                    emb1_input = gr.Textbox(
                        label="Embedding 1 (JSON)",
                        placeholder='[0.1, 0.2, ...]',
                        lines=3,
                    )
                    emb2_input = gr.Textbox(
                        label="Embedding 2 (JSON)",
                        placeholder='[0.1, 0.2, ...]',
                        lines=3,
                    )
                    sim_btn = gr.Button("Compute Similarity", variant="primary")
                with gr.Column():
                    similarity_output = gr.Number(
                        label="Cosine Similarity",
                        precision=4,
                    )
            sim_btn.click(
                similarity_interface,
                inputs=[emb1_input, emb2_input],
                outputs=[similarity_output],
                api_name="similarity_interface",
            )

        with gr.Tab("API Documentation"):
            gr.Markdown("""
            ## API Endpoints
            Each endpoint is a named Gradio API event. Call them with `gradio_client`:
            ```python
            from gradio_client import Client
            client = Client("<space-id-or-url>")
            ```

            ### 1. Single Text Embedding — `/single_embedding_interface`
            Input: one string. Output: the embedding as a JSON-encoded string.
            ```python
            client.predict("Your text here", api_name="/single_embedding_interface")
            ```

            ### 2. Batch Text Embedding — `/batch_embedding_interface`
            Input: one string containing one text per line (blank lines are ignored).
            Output: a JSON-encoded list of embeddings, in input order.
            ```python
            client.predict("Text 1\\nText 2\\nText 3", api_name="/batch_embedding_interface")
            ```

            ### 3. Similarity — `/similarity_interface`
            Inputs: two embeddings, each a JSON array of numbers.
            Output: their cosine similarity as a number (0.0 on invalid input).

            ### 4. Health Check — `/health_check`
            Returns a JSON object such as `{"status": "healthy", "model_loaded": true}`.
            """)
            # Expose health_check so the documented health endpoint actually exists.
            health_btn = gr.Button("Check Health")
            health_output = gr.JSON(label="Health")
            health_btn.click(
                health_check,
                inputs=None,
                outputs=[health_output],
                api_name="health_check",
            )

    return interface
def main():
    """Run the application: load the models, then build and serve the Gradio app.

    If load_model() reports failure, an error is logged and the app is not launched.
    """
    logger.info("Starting Qwen Embedding Model API...")
    model_ready = load_model()
    if model_ready:
        # Listen on all interfaces at port 7860; share=False means no public share link.
        launch_options = dict(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
            quiet=False,
        )
        create_interface().launch(**launch_options)
        return
    logger.error("Failed to load model. Exiting...")


if __name__ == "__main__":
    main()