Ojochegbeng committed
Commit 14cf01c · verified · 1 Parent(s): 8ede5e9

Update app.py

Files changed (1)
  1. app.py +91 -138
app.py CHANGED
@@ -1,4 +1,5 @@
-import gradio as gr
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
 import torch
 import numpy as np
 from transformers import AutoTokenizer, AutoModel
@@ -7,6 +8,7 @@ import json
 import logging
 import os
 import time
+import uvicorn

 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -193,160 +195,111 @@ def health_check():
     """Health check endpoint"""
     return {"status": "healthy", "model_loaded": model is not None}

-# Create Gradio interface
-def create_interface():
-    """Create the Gradio interface"""
-
-    with gr.Blocks(
-        title="Qwen Embedding Model",
-        theme=gr.themes.Soft(),
-        css="""
-        .gradio-container {
-            max-width: 1200px !important;
-            margin: auto !important;
+# Create FastAPI application
+app = FastAPI(
+    title="Qwen3 Embedding API",
+    description="A stable API for generating text embeddings using the Qwen3-Embedding-0.6B model",
+    version="1.0.0"
+)
+
+# Add CORS middleware
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# FastAPI endpoints
+@app.get("/")
+async def root():
+    """Root endpoint with API information"""
+    return {
+        "message": "Qwen3 Embedding API",
+        "version": "1.0.0",
+        "model": "Qwen3-Embedding-0.6B",
+        "endpoints": {
+            "health": "/health",
+            "predict": "/api/predict",
+            "docs": "/docs"
         }
-        """
-    ) as interface:
-
-        gr.Markdown("""
-        # Qwen Embedding Model API
+    }
+
+@app.get("/health")
+async def health():
+    """Health check endpoint"""
+    return health_check()
+
+@app.post("/api/predict")
+async def predict(data: dict):
+    """Main prediction endpoint for embeddings"""
+    try:
+        if "data" not in data:
+            raise HTTPException(status_code=400, detail="Missing 'data' field in request")

-        This space provides a stable API for generating text embeddings using the Qwen model.
-        The API supports both single text and batch processing.
-        """)
+        input_data = data["data"]

-        with gr.Tab("Single Text Embedding"):
-            gr.Markdown("Generate embedding for a single text input.")
-
-            with gr.Row():
-                with gr.Column():
-                    single_text_input = gr.Textbox(
-                        label="Input Text",
-                        placeholder="Enter text to generate embedding...",
-                        lines=3
-                    )
-                    single_btn = gr.Button("Generate Embedding", variant="primary")
-
-                with gr.Column():
-                    single_output = gr.Textbox(
-                        label="Embedding (JSON)",
-                        lines=10,
-                        interactive=False
-                    )
+        # Handle single text or batch texts
+        if isinstance(input_data, str):
+            # Single text
+            embeddings = generate_embeddings(input_data)
+            return {"data": [embeddings]}
+        elif isinstance(input_data, list):
+            if len(input_data) > 0 and isinstance(input_data[0], str):
+                # Single text in list
+                embeddings = generate_embeddings(input_data[0])
+                return {"data": [embeddings]}
+            elif len(input_data) > 0 and isinstance(input_data[0], list):
+                # Batch texts
+                embeddings = generate_embeddings(input_data[0])
+                return {"data": [embeddings]}
+            else:
+                raise HTTPException(status_code=400, detail="Invalid data format")
+        else:
+            raise HTTPException(status_code=400, detail="Invalid data type")

-            single_btn.click(
-                single_embedding_interface,
-                inputs=[single_text_input],
-                outputs=[single_output]
-            )
+    except Exception as e:
+        logger.error(f"Error in predict endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+@app.post("/api/similarity")
+async def similarity(data: dict):
+    """Compute similarity between two embeddings"""
+    try:
+        if "embedding1" not in data or "embedding2" not in data:
+            raise HTTPException(status_code=400, detail="Missing embedding1 or embedding2 field")

-        with gr.Tab("Batch Text Embedding"):
-            gr.Markdown("Generate embeddings for multiple texts (one per line).")
-
-            with gr.Row():
-                with gr.Column():
-                    batch_text_input = gr.Textbox(
-                        label="Input Texts (one per line)",
-                        placeholder="Enter multiple texts, one per line...",
-                        lines=5
-                    )
-                    batch_btn = gr.Button("Generate Embeddings", variant="primary")
-
-                with gr.Column():
-                    batch_output = gr.Textbox(
-                        label="Embeddings (JSON)",
-                        lines=10,
-                        interactive=False
-                    )
-
-            batch_btn.click(
-                batch_embedding_interface,
-                inputs=[batch_text_input],
-                outputs=[batch_output]
-            )
+        emb1 = data["embedding1"]
+        emb2 = data["embedding2"]

-        with gr.Tab("Similarity Calculator"):
-            gr.Markdown("Compute cosine similarity between two embeddings.")
-
-            with gr.Row():
-                with gr.Column():
-                    emb1_input = gr.Textbox(
-                        label="Embedding 1 (JSON)",
-                        placeholder='["0.1", "0.2", ...]',
-                        lines=3
-                    )
-                    emb2_input = gr.Textbox(
-                        label="Embedding 2 (JSON)",
-                        placeholder='["0.1", "0.2", ...]',
-                        lines=3
-                    )
-                    sim_btn = gr.Button("Compute Similarity", variant="primary")
-
-                with gr.Column():
-                    similarity_output = gr.Number(
-                        label="Cosine Similarity",
-                        precision=4
-                    )
-
-            sim_btn.click(
-                similarity_interface,
-                inputs=[emb1_input, emb2_input],
-                outputs=[similarity_output]
-            )
+        if not isinstance(emb1, list) or not isinstance(emb2, list):
+            raise HTTPException(status_code=400, detail="Embeddings must be lists")

-        with gr.Tab("API Documentation"):
-            gr.Markdown("""
-            ## API Endpoints
-
-            ### 1. Single Text Embedding
-            **POST** `/api/predict`
-
-            ```json
-            {
-                "data": ["Your text here"]
-            }
-            ```
-
-            ### 2. Batch Text Embedding
-            **POST** `/api/predict`
-
-            ```json
-            {
-                "data": [["Text 1", "Text 2", "Text 3"]]
-            }
-            ```
-
-            ### 3. Health Check
-            **GET** `/health`
-
-            Returns: `{"status": "healthy", "model_loaded": true}`
-
-            ## Response Format
-
-            All endpoints return embeddings as JSON arrays of floating-point numbers.
-            """)
-
-    return interface
+        sim = compute_similarity(emb1, emb2)
+        return {"similarity": sim}
+
+    except Exception as e:
+        logger.error(f"Error in similarity endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 def main():
     """Main function to run the application"""
-    logger.info("Starting Qwen Embedding Model API...")
+    logger.info("Starting Qwen3 Embedding Model API...")

     # Load model
     if not load_model():
         logger.error("Failed to load model. Exiting...")
         return

-    # Create and launch interface
-    interface = create_interface()
+    logger.info("Model loaded successfully. Starting FastAPI server...")

-    # Launch with public access
-    interface.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False,
-        show_error=True,
-        quiet=False
+    # Run with uvicorn
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=7860,
+        log_level="info"
    )

 if __name__ == "__main__":
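For reference, a minimal client sketch against the endpoints introduced in this commit (`/health`, `/api/predict`, `/api/similarity`). The base URL is a placeholder assumption; `uvicorn.run` in this diff binds to port 7860, so a local run would be reachable at `http://localhost:7860`. The exact shape of each embedding depends on `generate_embeddings`, which is not part of this commit.

```python
import requests

# Assumed base URL -- replace with the actual address where app.py is served.
BASE_URL = "http://localhost:7860"

# Health check: GET /health -> {"status": "healthy", "model_loaded": true}
print(requests.get(f"{BASE_URL}/health").json())

# Single-text embedding: POST /api/predict with {"data": ["<text>"]}
resp = requests.post(f"{BASE_URL}/api/predict",
                     json={"data": ["Your text here"]})
embedding = resp.json()["data"][0]

# Batch embeddings: POST /api/predict with {"data": [["<text 1>", "<text 2>", ...]]}
resp = requests.post(f"{BASE_URL}/api/predict",
                     json={"data": [["Text 1", "Text 2", "Text 3"]]})
batch_embeddings = resp.json()["data"][0]

# Cosine similarity between two embedding vectors: POST /api/similarity
# (both values must be JSON lists, per the isinstance checks in the handler)
resp = requests.post(f"{BASE_URL}/api/similarity",
                     json={"embedding1": embedding, "embedding2": embedding})
print(resp.json()["similarity"])
```

Since the new code is a FastAPI app, the interactive OpenAPI docs it advertises at `/docs` cover these same routes.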