Baktabek committed on
Commit
b2adce0
·
verified ·
1 Parent(s): 1a98a8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -11
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  HuggingFace Space: Jina Embeddings v3 API
3
  Free embedding service for AI-RAG-Core project
 
4
  """
5
 
6
  from fastapi import FastAPI, HTTPException
@@ -9,25 +10,46 @@ from typing import List
9
  import torch
10
  from transformers import AutoModel
11
  import logging
 
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
16
- app = FastAPI(title="Jina Embeddings v3 API", version="1.0.0")
17
 
18
  # Load model on startup
19
  model = None
 
 
 
 
 
 
20
 
21
  @app.on_event("startup")
22
  async def load_model():
23
- global model
24
  logger.info("Loading jina-embeddings-v3 model...")
 
 
 
 
 
 
 
 
 
25
  model = AutoModel.from_pretrained(
26
  'jinaai/jina-embeddings-v3',
27
  trust_remote_code=True,
28
  device_map="auto"
29
  )
30
- logger.info("Model loaded successfully!")
 
 
 
 
31
 
32
 
33
  class EmbeddingRequest(BaseModel):
@@ -47,13 +69,29 @@ async def create_embeddings(request: EmbeddingRequest):
47
  if model is None:
48
  raise HTTPException(status_code=503, detail="Model not loaded")
49
 
50
- try:
51
- # Generate embeddings
52
- embeddings = model.encode(
53
- request.input,
54
- task=request.task,
55
- batch_size=32
56
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Convert to list format
59
  if isinstance(embeddings, torch.Tensor):
@@ -67,20 +105,44 @@ async def create_embeddings(request: EmbeddingRequest):
67
  for i, emb in enumerate(embeddings)
68
  ]
69
 
 
 
 
 
 
 
 
70
  return EmbeddingResponse(data=data)
71
 
72
  except Exception as e:
73
  logger.error(f"Embedding generation failed: {e}")
 
 
 
 
 
 
74
  raise HTTPException(status_code=500, detail=str(e))
75
 
76
 
77
  @app.get("/health")
78
  async def health_check():
79
  """Health check endpoint"""
 
 
 
 
 
 
 
 
80
  return {
81
  "status": "healthy",
82
  "model": "jina-embeddings-v3",
83
- "model_loaded": model is not None
 
 
 
84
  }
85
 
86
 
@@ -89,9 +151,22 @@ async def root():
89
  """Root endpoint"""
90
  return {
91
  "service": "Jina Embeddings v3 API",
92
- "version": "1.0.0",
93
  "endpoints": {
94
  "embeddings": "/embeddings (POST)",
95
  "health": "/health (GET)"
 
 
 
 
96
  }
97
  }
 
 
 
 
 
 
 
 
 
 
1
  """
2
  HuggingFace Space: Jina Embeddings v3 API
3
  Free embedding service for AI-RAG-Core project
4
+ FIXED VERSION with memory management and batch limits
5
  """
6
 
7
  from fastapi import FastAPI, HTTPException
 
10
  import torch
11
  from transformers import AutoModel
12
  import logging
13
+ import gc
14
+ import asyncio
15
 
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
18
 
19
+ app = FastAPI(title="Jina Embeddings v3 API", version="1.0.1")
20
 
21
  # Load model on startup
22
  model = None
23
+ device = None
24
+
25
+ # Configuration
26
+ MAX_BATCH_SIZE = 50 # Limit batch size to prevent OOM
27
+ MAX_TEXT_LENGTH = 8192 # Jina v3 max tokens
28
+
29
 
30
  @app.on_event("startup")
31
  async def load_model():
32
+ global model, device
33
  logger.info("Loading jina-embeddings-v3 model...")
34
+
35
+ # Detect device
36
+ if torch.cuda.is_available():
37
+ device = "cuda"
38
+ logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
39
+ else:
40
+ device = "cpu"
41
+ logger.info("Using CPU")
42
+
43
  model = AutoModel.from_pretrained(
44
  'jinaai/jina-embeddings-v3',
45
  trust_remote_code=True,
46
  device_map="auto"
47
  )
48
+
49
+ # Set to eval mode to save memory
50
+ model.eval()
51
+
52
+ logger.info(f"Model loaded successfully on {device}!")
53
 
54
 
55
  class EmbeddingRequest(BaseModel):
 
69
  if model is None:
70
  raise HTTPException(status_code=503, detail="Model not loaded")
71
 
72
+ # Validate batch size
73
+ if len(request.input) > MAX_BATCH_SIZE:
74
+ raise HTTPException(
75
+ status_code=400,
76
+ detail=f"Batch size {len(request.input)} exceeds limit {MAX_BATCH_SIZE}"
 
77
  )
78
+
79
+ # Validate text length
80
+ for text in request.input:
81
+ if len(text) > MAX_TEXT_LENGTH:
82
+ raise HTTPException(
83
+ status_code=400,
84
+ detail=f"Text length exceeds {MAX_TEXT_LENGTH} characters"
85
+ )
86
+
87
+ try:
88
+ # Generate embeddings with no_grad to save memory
89
+ with torch.no_grad():
90
+ embeddings = model.encode(
91
+ request.input,
92
+ task=request.task,
93
+ batch_size=16 # Process in smaller chunks
94
+ )
95
 
96
  # Convert to list format
97
  if isinstance(embeddings, torch.Tensor):
 
105
  for i, emb in enumerate(embeddings)
106
  ]
107
 
108
+ # CRITICAL: Clear GPU cache after each request
109
+ if device == "cuda":
110
+ torch.cuda.empty_cache()
111
+
112
+ # Force garbage collection
113
+ gc.collect()
114
+
115
  return EmbeddingResponse(data=data)
116
 
117
  except Exception as e:
118
  logger.error(f"Embedding generation failed: {e}")
119
+
120
+ # Clear cache on error
121
+ if device == "cuda":
122
+ torch.cuda.empty_cache()
123
+ gc.collect()
124
+
125
  raise HTTPException(status_code=500, detail=str(e))
126
 
127
 
128
  @app.get("/health")
129
  async def health_check():
130
  """Health check endpoint"""
131
+ memory_info = {}
132
+
133
+ if torch.cuda.is_available():
134
+ memory_info = {
135
+ "gpu_memory_allocated": f"{torch.cuda.memory_allocated() / 1024**2:.2f} MB",
136
+ "gpu_memory_reserved": f"{torch.cuda.memory_reserved() / 1024**2:.2f} MB"
137
+ }
138
+
139
  return {
140
  "status": "healthy",
141
  "model": "jina-embeddings-v3",
142
+ "model_loaded": model is not None,
143
+ "device": device,
144
+ "max_batch_size": MAX_BATCH_SIZE,
145
+ **memory_info
146
  }
147
 
148
 
 
151
  """Root endpoint"""
152
  return {
153
  "service": "Jina Embeddings v3 API",
154
+ "version": "1.0.1",
155
  "endpoints": {
156
  "embeddings": "/embeddings (POST)",
157
  "health": "/health (GET)"
158
+ },
159
+ "limits": {
160
+ "max_batch_size": MAX_BATCH_SIZE,
161
+ "max_text_length": MAX_TEXT_LENGTH
162
  }
163
  }
164
+
165
+
166
+ @app.post("/clear_cache")
167
+ async def clear_cache():
168
+ """Manually clear GPU cache (admin endpoint)"""
169
+ if device == "cuda":
170
+ torch.cuda.empty_cache()
171
+ gc.collect()
172
+ return {"status": "cache cleared"}