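"""FastAPI service for generating text embeddings with Qwen/Qwen3-Embedding-0.6B.

Example requests once the server is running (routes and port are the ones
defined below; adjust the host if you deploy elsewhere):

    curl -X POST http://localhost:7860/api/predict \
         -H "Content-Type: application/json" \
         -d '{"texts": ["hello world"], "normalize": true}'

    curl -X POST http://localhost:7860/api/similarity \
         -H "Content-Type: application/json" \
         -d '{"text1": "hello", "text2": "hi"}'
"""
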
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel, AutoConfig
from typing import List, Union
import json
import logging
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_LENGTH = 512  # maximum sequence length in tokens

# Global variables for model and tokenizer
model = None
tokenizer = None

def load_model():
    """Load the Qwen3 embedding model and tokenizer"""
    global model, tokenizer
    
    try:
        logger.info(f"Loading Qwen3-Embedding-0.6B model on device: {DEVICE}")
        
        # Load tokenizer and model for Qwen3 embedding
        # First, try to load the config to understand the model structure
        config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
        logger.info(f"Model config loaded: {config.model_type}")
        
        # Load tokenizer - try different approaches
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        except Exception as tokenizer_error:
            logger.warning(f"Failed to load tokenizer with trust_remote_code=True: {tokenizer_error}")
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=False)
        
        # Load model
        model = AutoModel.from_pretrained(
            MODEL_NAME, 
            trust_remote_code=True,
            torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None
        )
        
        if DEVICE == "cpu":
            model = model.to(DEVICE)
        
        model.eval()
        
        # Test the model with a simple input
        test_input = tokenizer("test", return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
        with torch.no_grad():
            test_output = model(**test_input)
            logger.info(f"Model test successful. Output shape: {test_output.last_hidden_state.shape}")
            logger.info(f"Model config hidden size: {model.config.hidden_size}")
            logger.info(f"Tokenizer vocab size: {tokenizer.vocab_size}")
        
        logger.info("Qwen3-Embedding-0.6B model loaded successfully")
        return True
        
    except Exception as e:
        logger.error(f"Error loading Qwen3 model: {str(e)}")
        logger.error("No fallback available - Qwen3 model is required")
        return False

def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using Qwen3-Embedding-0.6B model"""
    global model, tokenizer
    
    if model is None or tokenizer is None:
        raise RuntimeError("Qwen3 model not loaded. Please ensure the model is properly loaded.")
    
    try:
        # Ensure texts is a list
        if isinstance(texts, str):
            texts = [texts]
            single_text = True
        else:
            single_text = False
        
        # Token-level truncation is handled by the tokenizer call below
        # (truncation=True, max_length=MAX_LENGTH); slicing the raw string to
        # MAX_LENGTH characters would conflate characters with tokens and
        # over-truncate, so no character-level slicing is done here.
        
        embeddings = []
        
        for text in texts:
            try:
                # Use the Qwen3 embedding model directly
                inputs = tokenizer(
                    text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=MAX_LENGTH
                ).to(DEVICE)

                with torch.no_grad():
                    outputs = model(**inputs)
                    
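                # NOTE: mean pooling is used below; the official
                # Qwen3-Embedding usage examples pool the last token instead,
                # so switch the pooling here if you need parity with the
                # reference implementation.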
                # For Qwen3 embedding models, use the last_hidden_state with mean pooling
                if hasattr(outputs, 'last_hidden_state'):
                    # Mean pooling over the sequence length dimension
                    attention_mask = inputs.get('attention_mask', None)
                    if attention_mask is not None:
                        # Apply the attention mask so padding tokens are
                        # excluded from the mean
                        token_embeddings = outputs.last_hidden_state
                        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
                        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze().cpu().numpy()
                    else:
                        # Simple mean pooling without an attention mask
                        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
                else:
                    # Fall back to the pooled output if the model provides one
                    embedding = outputs.pooler_output.squeeze().cpu().numpy()

                embeddings.append(embedding.tolist())
                        
            except Exception as e:
                logger.error(f"Error generating embedding for text: {str(e)}")
                raise RuntimeError(f"Failed to generate embedding: {str(e)}")
        
        return embeddings[0] if single_text else embeddings
        
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        raise RuntimeError(f"Embedding generation failed: {str(e)}")

def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Compute cosine similarity between two embeddings"""
    try:
        # Convert to numpy arrays
        emb1 = np.array(embedding1)
        emb2 = np.array(embedding2)
        
        # Compute cosine similarity
        dot_product = np.dot(emb1, emb2)
        norm1 = np.linalg.norm(emb1)
        norm2 = np.linalg.norm(emb2)
        
        if norm1 == 0 or norm2 == 0:
            return 0.0
        
        similarity = dot_product / (norm1 * norm2)
        return float(similarity)
        
    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0
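
# NOTE: the *_interface helpers below return JSON strings rather than Python
# objects and are not wired to any FastAPI route; they look like standalone
# wrappers (e.g. for a Gradio-style UI) kept for script use.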

def batch_embedding_interface(texts: str) -> str:
    """Interface for batch embedding generation"""
    try:
        # Split texts by newlines
        text_list = [text.strip() for text in texts.split('\n') if text.strip()]
        
        if not text_list:
            return json.dumps([])
        
        # Generate embeddings
        embeddings = generate_embeddings(text_list)
        
        # Return as JSON string
        return json.dumps(embeddings)
        
    except Exception as e:
        logger.error(f"Error in batch_embedding_interface: {str(e)}")
        return json.dumps([])

def single_embedding_interface(text: str) -> str:
    """Interface for single embedding generation"""
    try:
        if not text.strip():
            return json.dumps([])
        
        # Generate embedding
        embedding = generate_embeddings(text)
        
        # Return as JSON string
        return json.dumps(embedding)
        
    except Exception as e:
        logger.error(f"Error in single_embedding_interface: {str(e)}")
        return json.dumps([])

def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Interface for computing similarity between two embeddings"""
    try:
        # Parse embeddings from JSON strings
        emb1 = json.loads(embedding1)
        emb2 = json.loads(embedding2)
        
        # Compute similarity
        similarity = compute_similarity(emb1, emb2)
        
        return similarity
        
    except Exception as e:
        logger.error(f"Error in similarity_interface: {str(e)}")
        return 0.0

def health_check():
    """Health check endpoint"""
    model_info = {
        "status": "healthy" if model is not None and tokenizer is not None else "unhealthy",
        "model_loaded": model is not None and tokenizer is not None,
        "model_name": MODEL_NAME,
        "device": DEVICE,
        "max_length": MAX_LENGTH
    }
    
    if model is not None and tokenizer is not None:
        if hasattr(model, 'config'):
            model_info["model_type"] = "Qwen3-Embedding"
            model_info["embedding_dimension"] = getattr(model.config, 'hidden_size', 1024)
            model_info["tokenizer_loaded"] = True
        else:
            model_info["model_type"] = "Unknown"
            model_info["embedding_dimension"] = "Unknown"
            model_info["tokenizer_loaded"] = False
    else:
        model_info["model_type"] = "Not Loaded"
        model_info["embedding_dimension"] = "N/A"
        model_info["tokenizer_loaded"] = tokenizer is not None
    
    return model_info

# Create FastAPI application
app = FastAPI(
    title="Qwen3 Embedding API",
    description="A stable API for generating text embeddings using the Qwen3-Embedding-0.6B model",
    version="1.0.0"
)

# Add CORS middleware
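# (wide-open policy: fine for demos; restrict allow_origins in production)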
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# FastAPI endpoints
@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Qwen3 Embedding API",
        "version": "1.0.0",
        "model": "Qwen3-Embedding-0.6B",
        "endpoints": {
            "health": "/health",
            "predict": "/api/predict",
            "docs": "/docs"
        }
    }

@app.get("/health")
async def health():
    """Health check endpoint"""
    return health_check()

@app.post("/api/predict")
async def predict(data: dict):
    """Main prediction endpoint for embeddings"""
    try:
        # Check for new format first (texts parameter)
        if "texts" in data:
            texts = data["texts"]
            normalize = data.get("normalize", True)
            
            if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
                raise HTTPException(status_code=400, detail="'texts' must be a list of strings")

            if len(texts) == 0:
                raise HTTPException(status_code=400, detail="'texts' list cannot be empty")
            
            # Generate embeddings
            logger.info(f"Generating embeddings for {len(texts)} texts")
            embeddings = generate_embeddings(texts)
            logger.info(f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0]) if embeddings else 0}")
            
            # Normalize embeddings to unit length if requested
            if normalize:
                try:
                    # Convert back to plain Python lists so the response stays
                    # JSON-serializable; a zero norm falls back to 1.0
                    embeddings = [
                        (np.asarray(emb) / (np.linalg.norm(emb) or 1.0)).tolist()
                        for emb in embeddings
                    ]
                    logger.info("Embeddings normalized")
                except Exception as norm_error:
                    logger.warning(f"Normalization failed: {str(norm_error)}, returning unnormalized embeddings")
                    # Continue with unnormalized embeddings
            
            # Count tokens with the model tokenizer; whitespace splitting
            # would undercount subword tokens
            prompt_tokens = sum(len(tokenizer.encode(text)) for text in texts)
            return {
                "embeddings": embeddings,
                "model": MODEL_NAME,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "total_tokens": prompt_tokens
                }
            }
        
        # Fallback to old format for backward compatibility
        elif "data" in data:
            input_data = data["data"]
            
            # Handle single text or batch texts
            if isinstance(input_data, str):
                # Single text
                embeddings = generate_embeddings(input_data)
                return {"data": [embeddings]}
            elif isinstance(input_data, list):
                if len(input_data) > 0 and isinstance(input_data[0], str):
                    # Single text in list
                    embeddings = generate_embeddings(input_data[0])
                    return {"data": [embeddings]}
                elif len(input_data) > 0 and isinstance(input_data[0], list):
                    # Batch texts
                    embeddings = generate_embeddings(input_data[0])
                    return {"data": [embeddings]}
                else:
                    raise HTTPException(status_code=400, detail="Invalid data format")
            else:
                raise HTTPException(status_code=400, detail="Invalid data type")
        else:
            raise HTTPException(status_code=400, detail="Missing 'texts' or 'data' field in request")
            
    except HTTPException:
        # Let intentional 4xx responses pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Error in predict endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@app.post("/api/similarity")
async def similarity(data: dict):
    """Compute similarity between two texts or embeddings"""
    try:
        # Check for new format first (text1, text2 parameters)
        if "text1" in data and "text2" in data:
            text1 = data["text1"]
            text2 = data["text2"]
            
            if not isinstance(text1, str) or not isinstance(text2, str):
                raise HTTPException(status_code=400, detail="text1 and text2 must be strings")
            
            # Generate embeddings for both texts
            emb1 = generate_embeddings(text1)
            emb2 = generate_embeddings(text2)
            
            # Compute similarity
            sim = compute_similarity(emb1, emb2)
            return {
                "similarity": sim,
                "model": MODEL_NAME,
                "text1": text1,
                "text2": text2
            }
        
        # Fallback to old format (embedding1, embedding2 parameters)
        elif "embedding1" in data and "embedding2" in data:
            emb1 = data["embedding1"]
            emb2 = data["embedding2"]
            
            if not isinstance(emb1, list) or not isinstance(emb2, list):
                raise HTTPException(status_code=400, detail="Embeddings must be lists")
            
            sim = compute_similarity(emb1, emb2)
            return {"similarity": sim}
        
        else:
            raise HTTPException(status_code=400, detail="Missing 'text1' and 'text2' or 'embedding1' and 'embedding2' fields")
        
    except HTTPException:
        # Let intentional 4xx responses pass through unchanged
        raise
    except Exception as e:
        logger.error(f"Error in similarity endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

def main():
    """Main function to run the application"""
    logger.info("Starting Qwen3 Embedding Model API...")
    
    # Load model
    if not load_model():
        logger.error("Failed to load model. Exiting...")
        return
    
    logger.info("Model loaded successfully. Starting FastAPI server...")
    
    # Run with uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info"
    )

if __name__ == "__main__":
    main()