Ojochegbeng committed on
Commit
8ede5e9
·
verified ·
1 Parent(s): 24131c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +353 -358
app.py CHANGED
@@ -1,358 +1,353 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- from transformers import AutoTokenizer, AutoModel
5
- from typing import List, Union
6
- import json
7
- import logging
8
- import os
9
- from sentence_transformers import SentenceTransformer
10
- import time
11
-
12
- # Configure logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
-
16
- # Model configuration
17
- MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B" # Qwen3 Embedding model
18
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
- MAX_LENGTH = 512
20
-
21
- # Global variables for model and tokenizer
22
- model = None
23
- tokenizer = None
24
- sentence_transformer = None
25
-
26
- def load_model():
27
- """Load the Qwen model and tokenizer"""
28
- global model, tokenizer, sentence_transformer
29
-
30
- try:
31
- logger.info(f"Loading model on device: {DEVICE}")
32
-
33
- # Load tokenizer and model
34
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
35
- model = AutoModel.from_pretrained(
36
- MODEL_NAME,
37
- trust_remote_code=True,
38
- torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
39
- device_map="auto" if DEVICE == "cuda" else None
40
- )
41
-
42
- if DEVICE == "cpu":
43
- model = model.to(DEVICE)
44
-
45
- model.eval()
46
-
47
- # Also load sentence transformer as backup
48
- sentence_transformer = SentenceTransformer('all-MiniLM-L6-v2')
49
-
50
- logger.info("Model loaded successfully")
51
- return True
52
-
53
- except Exception as e:
54
- logger.error(f"Error loading model: {str(e)}")
55
- return False
56
-
57
- def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
58
- """Generate embeddings for input text(s) using Qwen3 Embedding model"""
59
- global model, tokenizer, sentence_transformer
60
-
61
- try:
62
- # Ensure texts is a list
63
- if isinstance(texts, str):
64
- texts = [texts]
65
- single_text = True
66
- else:
67
- single_text = False
68
-
69
- # Truncate texts if too long
70
- texts = [text[:MAX_LENGTH] for text in texts]
71
-
72
- embeddings = []
73
-
74
- for text in texts:
75
- try:
76
- # Method 1: Try using the Qwen3 embedding model directly
77
- if model and tokenizer:
78
- inputs = tokenizer(
79
- text,
80
- return_tensors="pt",
81
- padding=True,
82
- truncation=True,
83
- max_length=MAX_LENGTH
84
- ).to(DEVICE)
85
-
86
- with torch.no_grad():
87
- outputs = model(**inputs)
88
- # For Qwen3 embedding model, use the pooled output
89
- if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
90
- embedding = outputs.pooler_output.squeeze().cpu().numpy()
91
- else:
92
- # Fallback to mean pooling of last hidden state
93
- embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
94
- embeddings.append(embedding.tolist())
95
-
96
- else:
97
- # Method 2: Fallback to sentence transformer
98
- if sentence_transformer:
99
- embedding = sentence_transformer.encode(text)
100
- embeddings.append(embedding.tolist())
101
- else:
102
- raise Exception("No model available")
103
-
104
- except Exception as e:
105
- logger.warning(f"Error generating embedding for text: {str(e)}")
106
- # Fallback to sentence transformer
107
- if sentence_transformer:
108
- embedding = sentence_transformer.encode(text)
109
- embeddings.append(embedding.tolist())
110
- else:
111
- # Return zero vector as last resort
112
- embeddings.append([0.0] * 1024) # Qwen3-Embedding-0.6B has 1024 dimensions
113
-
114
- return embeddings[0] if single_text else embeddings
115
-
116
- except Exception as e:
117
- logger.error(f"Error in generate_embeddings: {str(e)}")
118
- # Return zero vectors as fallback
119
- if single_text:
120
- return [0.0] * 1024
121
- else:
122
- return [[0.0] * 1024] * len(texts)
123
-
124
- def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
125
- """Compute cosine similarity between two embeddings"""
126
- try:
127
- # Convert to numpy arrays
128
- emb1 = np.array(embedding1)
129
- emb2 = np.array(embedding2)
130
-
131
- # Compute cosine similarity
132
- dot_product = np.dot(emb1, emb2)
133
- norm1 = np.linalg.norm(emb1)
134
- norm2 = np.linalg.norm(emb2)
135
-
136
- if norm1 == 0 or norm2 == 0:
137
- return 0.0
138
-
139
- similarity = dot_product / (norm1 * norm2)
140
- return float(similarity)
141
-
142
- except Exception as e:
143
- logger.error(f"Error computing similarity: {str(e)}")
144
- return 0.0
145
-
146
- def batch_embedding_interface(texts: str) -> str:
147
- """Interface for batch embedding generation"""
148
- try:
149
- # Split texts by newlines
150
- text_list = [text.strip() for text in texts.split('\n') if text.strip()]
151
-
152
- if not text_list:
153
- return json.dumps([])
154
-
155
- # Generate embeddings
156
- embeddings = generate_embeddings(text_list)
157
-
158
- # Return as JSON string
159
- return json.dumps(embeddings)
160
-
161
- except Exception as e:
162
- logger.error(f"Error in batch_embedding_interface: {str(e)}")
163
- return json.dumps([])
164
-
165
- def single_embedding_interface(text: str) -> str:
166
- """Interface for single embedding generation"""
167
- try:
168
- if not text.strip():
169
- return json.dumps([])
170
-
171
- # Generate embedding
172
- embedding = generate_embeddings(text)
173
-
174
- # Return as JSON string
175
- return json.dumps(embedding)
176
-
177
- except Exception as e:
178
- logger.error(f"Error in single_embedding_interface: {str(e)}")
179
- return json.dumps([])
180
-
181
- def similarity_interface(embedding1: str, embedding2: str) -> float:
182
- """Interface for computing similarity between two embeddings"""
183
- try:
184
- # Parse embeddings from JSON strings
185
- emb1 = json.loads(embedding1)
186
- emb2 = json.loads(embedding2)
187
-
188
- # Compute similarity
189
- similarity = compute_similarity(emb1, emb2)
190
-
191
- return similarity
192
-
193
- except Exception as e:
194
- logger.error(f"Error in similarity_interface: {str(e)}")
195
- return 0.0
196
-
197
- def health_check():
198
- """Health check endpoint"""
199
- return {"status": "healthy", "model_loaded": model is not None}
200
-
201
- # Create Gradio interface
202
- def create_interface():
203
- """Create the Gradio interface"""
204
-
205
- with gr.Blocks(
206
- title="Qwen Embedding Model",
207
- theme=gr.themes.Soft(),
208
- css="""
209
- .gradio-container {
210
- max-width: 1200px !important;
211
- margin: auto !important;
212
- }
213
- """
214
- ) as interface:
215
-
216
- gr.Markdown("""
217
- # Qwen Embedding Model API
218
-
219
- This space provides a stable API for generating text embeddings using the Qwen model.
220
- The API supports both single text and batch processing.
221
- """)
222
-
223
- with gr.Tab("Single Text Embedding"):
224
- gr.Markdown("Generate embedding for a single text input.")
225
-
226
- with gr.Row():
227
- with gr.Column():
228
- single_text_input = gr.Textbox(
229
- label="Input Text",
230
- placeholder="Enter text to generate embedding...",
231
- lines=3
232
- )
233
- single_btn = gr.Button("Generate Embedding", variant="primary")
234
-
235
- with gr.Column():
236
- single_output = gr.Textbox(
237
- label="Embedding (JSON)",
238
- lines=10,
239
- interactive=False
240
- )
241
-
242
- single_btn.click(
243
- single_embedding_interface,
244
- inputs=[single_text_input],
245
- outputs=[single_output]
246
- )
247
-
248
- with gr.Tab("Batch Text Embedding"):
249
- gr.Markdown("Generate embeddings for multiple texts (one per line).")
250
-
251
- with gr.Row():
252
- with gr.Column():
253
- batch_text_input = gr.Textbox(
254
- label="Input Texts (one per line)",
255
- placeholder="Enter multiple texts, one per line...",
256
- lines=5
257
- )
258
- batch_btn = gr.Button("Generate Embeddings", variant="primary")
259
-
260
- with gr.Column():
261
- batch_output = gr.Textbox(
262
- label="Embeddings (JSON)",
263
- lines=10,
264
- interactive=False
265
- )
266
-
267
- batch_btn.click(
268
- batch_embedding_interface,
269
- inputs=[batch_text_input],
270
- outputs=[batch_output]
271
- )
272
-
273
- with gr.Tab("Similarity Calculator"):
274
- gr.Markdown("Compute cosine similarity between two embeddings.")
275
-
276
- with gr.Row():
277
- with gr.Column():
278
- emb1_input = gr.Textbox(
279
- label="Embedding 1 (JSON)",
280
- placeholder='["0.1", "0.2", ...]',
281
- lines=3
282
- )
283
- emb2_input = gr.Textbox(
284
- label="Embedding 2 (JSON)",
285
- placeholder='["0.1", "0.2", ...]',
286
- lines=3
287
- )
288
- sim_btn = gr.Button("Compute Similarity", variant="primary")
289
-
290
- with gr.Column():
291
- similarity_output = gr.Number(
292
- label="Cosine Similarity",
293
- precision=4
294
- )
295
-
296
- sim_btn.click(
297
- similarity_interface,
298
- inputs=[emb1_input, emb2_input],
299
- outputs=[similarity_output]
300
- )
301
-
302
- with gr.Tab("API Documentation"):
303
- gr.Markdown("""
304
- ## API Endpoints
305
-
306
- ### 1. Single Text Embedding
307
- **POST** `/api/predict`
308
-
309
- ```json
310
- {
311
- "data": ["Your text here"]
312
- }
313
- ```
314
-
315
- ### 2. Batch Text Embedding
316
- **POST** `/api/predict`
317
-
318
- ```json
319
- {
320
- "data": [["Text 1", "Text 2", "Text 3"]]
321
- }
322
- ```
323
-
324
- ### 3. Health Check
325
- **GET** `/health`
326
-
327
- Returns: `{"status": "healthy", "model_loaded": true}`
328
-
329
- ## Response Format
330
-
331
- All endpoints return embeddings as JSON arrays of floating-point numbers.
332
- """)
333
-
334
- return interface
335
-
336
- def main():
337
- """Main function to run the application"""
338
- logger.info("Starting Qwen Embedding Model API...")
339
-
340
- # Load model
341
- if not load_model():
342
- logger.error("Failed to load model. Exiting...")
343
- return
344
-
345
- # Create and launch interface
346
- interface = create_interface()
347
-
348
- # Launch with public access
349
- interface.launch(
350
- server_name="0.0.0.0",
351
- server_port=7860,
352
- share=False,
353
- show_error=True,
354
- quiet=False
355
- )
356
-
357
- if __name__ == "__main__":
358
- main()
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from typing import List, Union
6
+ import json
7
+ import logging
8
+ import os
9
+ import time
10
+
11
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model configuration
MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B"  # Qwen3 Embedding model
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Used both as a character cap on raw input text and as the tokenizer's
# max token length (see generate_embeddings).
MAX_LENGTH = 512

# Global variables for model and tokenizer.
# `model` holds either a transformers AutoModel or, after the fallback path
# in load_model, a SentenceTransformer; `tokenizer` stays None in the
# fallback case so generate_embeddings can tell the two backends apart.
model = None
tokenizer = None
24
def load_model():
    """Load the Qwen3 embedding model and tokenizer into module globals.

    Tries the Qwen3 model first; if that fails for any reason, falls back
    to a SentenceTransformer backend (leaving ``tokenizer`` as None so
    callers can distinguish the two). Returns True when either backend
    loads, False when both fail.
    """
    global model, tokenizer

    use_cuda = DEVICE == "cuda"
    try:
        logger.info(f"Loading Qwen3 embedding model on device: {DEVICE}")

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None
        )

        # device_map="auto" already places the model on GPU; move explicitly
        # only for the CPU case.
        if not use_cuda:
            model = model.to(DEVICE)

        model.eval()

        logger.info("Qwen3 embedding model loaded successfully")
        return True

    except Exception as e:
        logger.error(f"Error loading Qwen3 model: {str(e)}")

    # Primary load failed: try a lightweight sentence-transformers backend.
    try:
        logger.info("Trying fallback model loading...")
        from sentence_transformers import SentenceTransformer
        model = SentenceTransformer('all-MiniLM-L6-v2')
        tokenizer = None
        logger.info("Fallback model loaded successfully")
        return True
    except Exception as fallback_error:
        logger.error(f"Fallback model loading also failed: {str(fallback_error)}")
        return False
61
+
62
def generate_embeddings(texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
    """Generate embeddings for input text(s) using Qwen3 or fallback model.

    Args:
        texts: A single string or a list of strings. Each text is capped at
            MAX_LENGTH characters before encoding; the tokenizer additionally
            truncates to MAX_LENGTH tokens.

    Returns:
        A single embedding (list of floats) when ``texts`` is a string, or a
        list of embeddings when it is a list. On any per-text failure a zero
        vector of length 384 is appended so the output stays aligned with the
        input. NOTE(review): the Qwen path emits vectors of the model's
        hidden size (presumably 1024 for Qwen3-Embedding-0.6B — confirm), so
        a partially failed batch can mix vector lengths.
    """
    global model, tokenizer

    # Decide the return shape up front so the outer fallback below can always
    # build a result of the right shape, even if normalization itself fails.
    single_text = isinstance(texts, str)
    if single_text:
        texts = [texts]

    try:
        # Character-level cap; token-level truncation happens in the tokenizer.
        texts = [text[:MAX_LENGTH] for text in texts]

        embeddings = []

        for text in texts:
            try:
                # Explicit None checks: truthiness of model objects is not a
                # reliable "is loaded" signal.
                if model is not None and tokenizer is not None:
                    # Method 1: the Qwen model with its own tokenizer.
                    inputs = tokenizer(
                        text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=MAX_LENGTH
                    ).to(DEVICE)

                    with torch.no_grad():
                        outputs = model(**inputs)
                        # Attention-mask-aware mean pooling so padding tokens
                        # (if any) do not dilute the embedding. For a single
                        # unpadded sequence this equals a plain mean.
                        mask = inputs["attention_mask"].unsqueeze(-1)
                        summed = (outputs.last_hidden_state * mask).sum(dim=1)
                        counts = mask.sum(dim=1).clamp(min=1)
                        embedding = (summed / counts).squeeze(0).cpu().numpy()
                    embeddings.append(embedding.tolist())

                elif model is not None and hasattr(model, 'encode'):
                    # Method 2: sentence-transformers fallback backend.
                    embedding = model.encode(text)
                    embeddings.append(embedding.tolist())
                else:
                    raise Exception("No model available")

            except Exception as e:
                logger.warning(f"Error generating embedding for text: {str(e)}")
                # Zero vector keeps the output aligned with the input texts.
                embeddings.append([0.0] * 384)  # fallback (MiniLM) dimension

        return embeddings[0] if single_text else embeddings

    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        # Return zero vectors of the expected shape as a last resort.
        if single_text:
            return [0.0] * 384
        return [[0.0] * 384] * len(texts)
118
+
119
def compute_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Return the cosine similarity of two embedding vectors.

    Yields 0.0 when either vector has zero norm or when any error occurs.
    """
    try:
        vec_a = np.asarray(embedding1)
        vec_b = np.asarray(embedding2)

        norm_a = np.linalg.norm(vec_a)
        norm_b = np.linalg.norm(vec_b)

        # A zero vector has no direction; define its similarity as 0.
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))

    except Exception as e:
        logger.error(f"Error computing similarity: {str(e)}")
        return 0.0
140
+
141
def batch_embedding_interface(texts: str) -> str:
    """Generate embeddings for several texts given as one newline-separated string.

    Returns a JSON array of embeddings; "[]" for blank input or on error.
    """
    try:
        # One input text per non-blank line.
        stripped = (line.strip() for line in texts.split('\n'))
        text_list = [line for line in stripped if line]

        if not text_list:
            return json.dumps([])

        return json.dumps(generate_embeddings(text_list))

    except Exception as e:
        logger.error(f"Error in batch_embedding_interface: {str(e)}")
        return json.dumps([])
159
+
160
def single_embedding_interface(text: str) -> str:
    """Generate an embedding for one text and return it JSON-encoded.

    Returns "[]" for blank input or on error.
    """
    try:
        if not text.strip():
            return json.dumps([])

        return json.dumps(generate_embeddings(text))

    except Exception as e:
        logger.error(f"Error in single_embedding_interface: {str(e)}")
        return json.dumps([])
175
+
176
def similarity_interface(embedding1: str, embedding2: str) -> float:
    """Parse two JSON-encoded embeddings and return their cosine similarity.

    Returns 0.0 when either input fails to parse or the computation fails.
    """
    try:
        vec1 = json.loads(embedding1)
        vec2 = json.loads(embedding2)
        return compute_similarity(vec1, vec2)

    except Exception as e:
        logger.error(f"Error in similarity_interface: {str(e)}")
        return 0.0
191
+
192
def health_check():
    """Health check endpoint: report liveness and whether a model is loaded."""
    is_loaded = model is not None
    return {"status": "healthy", "model_loaded": is_loaded}
195
+
196
+ # Create Gradio interface
197
def create_interface():
    """Build and return the Gradio Blocks UI.

    Four tabs: single-text embedding, batch embedding, a cosine-similarity
    calculator, and static API documentation. Buttons are wired to the
    module-level *_interface handler functions.
    """

    with gr.Blocks(
        title="Qwen Embedding Model",
        theme=gr.themes.Soft(),
        # Keep the page narrow and centered.
        css="""
        .gradio-container {
            max-width: 1200px !important;
            margin: auto !important;
        }
        """
    ) as interface:

        gr.Markdown("""
        # Qwen Embedding Model API

        This space provides a stable API for generating text embeddings using the Qwen model.
        The API supports both single text and batch processing.
        """)

        with gr.Tab("Single Text Embedding"):
            gr.Markdown("Generate embedding for a single text input.")

            with gr.Row():
                with gr.Column():
                    single_text_input = gr.Textbox(
                        label="Input Text",
                        placeholder="Enter text to generate embedding...",
                        lines=3
                    )
                    single_btn = gr.Button("Generate Embedding", variant="primary")

                with gr.Column():
                    # Read-only output: the embedding serialized as JSON.
                    single_output = gr.Textbox(
                        label="Embedding (JSON)",
                        lines=10,
                        interactive=False
                    )

            single_btn.click(
                single_embedding_interface,
                inputs=[single_text_input],
                outputs=[single_output]
            )

        with gr.Tab("Batch Text Embedding"):
            gr.Markdown("Generate embeddings for multiple texts (one per line).")

            with gr.Row():
                with gr.Column():
                    batch_text_input = gr.Textbox(
                        label="Input Texts (one per line)",
                        placeholder="Enter multiple texts, one per line...",
                        lines=5
                    )
                    batch_btn = gr.Button("Generate Embeddings", variant="primary")

                with gr.Column():
                    batch_output = gr.Textbox(
                        label="Embeddings (JSON)",
                        lines=10,
                        interactive=False
                    )

            batch_btn.click(
                batch_embedding_interface,
                inputs=[batch_text_input],
                outputs=[batch_output]
            )

        with gr.Tab("Similarity Calculator"):
            gr.Markdown("Compute cosine similarity between two embeddings.")

            with gr.Row():
                with gr.Column():
                    emb1_input = gr.Textbox(
                        label="Embedding 1 (JSON)",
                        placeholder='["0.1", "0.2", ...]',
                        lines=3
                    )
                    emb2_input = gr.Textbox(
                        label="Embedding 2 (JSON)",
                        placeholder='["0.1", "0.2", ...]',
                        lines=3
                    )
                    sim_btn = gr.Button("Compute Similarity", variant="primary")

                with gr.Column():
                    similarity_output = gr.Number(
                        label="Cosine Similarity",
                        precision=4
                    )

            sim_btn.click(
                similarity_interface,
                inputs=[emb1_input, emb2_input],
                outputs=[similarity_output]
            )

        with gr.Tab("API Documentation"):
            # Static markdown only; the endpoints described are served by
            # Gradio itself, not defined in this file.
            gr.Markdown("""
            ## API Endpoints

            ### 1. Single Text Embedding
            **POST** `/api/predict`

            ```json
            {
                "data": ["Your text here"]
            }
            ```

            ### 2. Batch Text Embedding
            **POST** `/api/predict`

            ```json
            {
                "data": [["Text 1", "Text 2", "Text 3"]]
            }
            ```

            ### 3. Health Check
            **GET** `/health`

            Returns: `{"status": "healthy", "model_loaded": true}`

            ## Response Format

            All endpoints return embeddings as JSON arrays of floating-point numbers.
            """)

    return interface
330
+
331
def main():
    """Entry point: load a model, build the UI, and serve it on port 7860."""
    logger.info("Starting Qwen Embedding Model API...")

    # Without any model there is nothing useful to serve.
    if not load_model():
        logger.error("Failed to load model. Exiting...")
        return

    app = create_interface()

    # Listen on all interfaces (required inside containers / Spaces).
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )
351
+
352
+ if __name__ == "__main__":
353
+ main()