Vaibhav Gaikwad commited on
Commit
0e8ef8a
Β·
1 Parent(s): 021b753

gpu to cpu fallback when limit quota exceeds

Browse files
Files changed (1) hide show
  1. app.py +64 -15
app.py CHANGED
@@ -3,15 +3,20 @@ audiolens β€” app.py
3
  huggingface space backend (zerogpu + gradio native api)
4
 
5
  api endpoints (via gradio):
6
- /call/classify β€” document type classification (dit-base)
7
- /call/ocr β€” text extraction (easyocr)
8
- /call/speak β€” text to speech (kokoro)
9
- /call/health β€” check if space is warm
 
10
 
11
  the pwa calls these using the gradio js client (@gradio/client)
12
  or via gradio's rest api. each function decorated with @spaces.GPU
13
  gets a gpu allocation only for the duration of that call.
14
 
 
 
 
 
15
  llm extraction (gemini) is called directly from the pwa β€” not here.
16
  """
17
 
@@ -30,6 +35,17 @@ import gradio as gr
30
  from j2_preprocess import preprocess
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  # ============================================================
34
  # -- dit class mapping --
35
  # ============================================================
@@ -104,8 +120,9 @@ def classify_fn(image):
104
  return {'error': 'no image provided'}
105
 
106
  try:
107
- dit_model.to('cuda')
108
- inputs = dit_processor(images=image, return_tensors='pt').to('cuda')
 
109
 
110
  with torch.no_grad():
111
  logits = dit_model(**inputs).logits
@@ -122,6 +139,34 @@ def classify_fn(image):
122
  return {'error': str(e)}
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def ocr_gpu(clean_image):
126
  """
127
  runs easyocr on a preprocessed image.
@@ -153,21 +198,14 @@ def ocr_fn(image):
153
  return 'error: no image provided'
154
 
155
  try:
156
- # convert pil to cv2 for preprocessing
157
  cv2_image = pil_to_cv2(image)
158
 
159
- # # preprocessing runs on cpu β€” outside the gpu function
160
- # clean = preprocess(cv2_image)
161
-
162
- # # ocr inference on cpu
163
- # text = ocr_gpu(clean)
164
-
165
- # trusting easyOCR for test preprocess
166
  # clean = preprocess(cv2_image)
167
 
168
  # ocr inference on cpu
169
  text = ocr_gpu(cv2_image)
170
-
171
  return text
172
 
173
  except Exception as e:
@@ -274,6 +312,17 @@ with gr.Blocks(title='AudioLens API') as demo:
274
  api_name='health',
275
  )
276
 
 
 
 
 
 
 
 
 
 
 
 
277
  gr.Markdown("""
278
  ---
279
  **API endpoints** (use via [@gradio/client](https://www.gradio.app/guides/getting-started-with-the-js-client)):
 
3
  huggingface space backend (zerogpu + gradio native api)
4
 
5
  api endpoints (via gradio):
6
+ /call/classify β€” document type classification (dit-base, gpu)
7
+ /call/classify_cpu β€” same but cpu-only fallback (no gpu quota needed)
8
+ /call/ocr β€” text extraction (easyocr, cpu)
9
+ /call/speak β€” text to speech (kokoro, gpu)
10
+ /call/health β€” check if space is warm
11
 
12
  the pwa calls these using the gradio js client (@gradio/client)
13
  or via gradio's rest api. each function decorated with @spaces.GPU
14
  gets a gpu allocation only for the duration of that call.
15
 
16
+ when gpu quota is exceeded, the pwa falls back to:
17
+ - /call/classify_cpu for classification (slower but works)
18
+ - browser Web Speech API for tts (no server needed)
19
+
20
  llm extraction (gemini) is called directly from the pwa β€” not here.
21
  """
22
 
 
35
  from j2_preprocess import preprocess
36
 
37
 
38
+ def get_device():
39
+ """picks the best available device at call time.
40
+ on hf, cuda is only available inside @spaces.GPU functions.
41
+ on mac, mps is always available. falls back to cpu."""
42
+ if torch.cuda.is_available():
43
+ return 'cuda'
44
+ if torch.backends.mps.is_available():
45
+ return 'mps'
46
+ return 'cpu'
47
+
48
+
49
  # ============================================================
50
  # -- dit class mapping --
51
  # ============================================================
 
120
  return {'error': 'no image provided'}
121
 
122
  try:
123
+ device = get_device()
124
+ dit_model.to(device)
125
+ inputs = dit_processor(images=image, return_tensors='pt').to(device)
126
 
127
  with torch.no_grad():
128
  logits = dit_model(**inputs).logits
 
139
  return {'error': str(e)}
140
 
141
 
142
+ def classify_cpu_fn(image):
143
+ """
144
+ cpu-only fallback for classification.
145
+ called when gpu quota is exceeded.
146
+ same logic as classify_fn but runs entirely on cpu β€” slower but no quota.
147
+ called via gradio api: /call/classify_cpu
148
+ """
149
+ if image is None:
150
+ return {'error': 'no image provided'}
151
+
152
+ try:
153
+ dit_model.to('cpu')
154
+ inputs = dit_processor(images=image, return_tensors='pt').to('cpu')
155
+
156
+ with torch.no_grad():
157
+ logits = dit_model(**inputs).logits
158
+
159
+ selected_logits = logits[0, SELECTED_RVL_IDX]
160
+ pred_idx = selected_logits.argmax().item()
161
+ confidence = torch.softmax(selected_logits, dim=0)[pred_idx].item()
162
+ doc_type = DIT_CLASS_MAP[SELECTED_RVL_IDX[pred_idx]]
163
+
164
+ return {'doc_type': doc_type, 'confidence': round(confidence, 4)}
165
+
166
+ except Exception as e:
167
+ return {'error': str(e)}
168
+
169
+
170
  def ocr_gpu(clean_image):
171
  """
172
  runs easyocr on a preprocessed image.
 
198
  return 'error: no image provided'
199
 
200
  try:
201
+ # convert pil to cv2
202
  cv2_image = pil_to_cv2(image)
203
 
204
+ # preprocessing β€” easyocr handles its own internally test
 
 
 
 
 
 
205
  # clean = preprocess(cv2_image)
206
 
207
  # ocr inference on cpu
208
  text = ocr_gpu(cv2_image)
 
209
  return text
210
 
211
  except Exception as e:
 
312
  api_name='health',
313
  )
314
 
315
+ # -- cpu fallbacks (hidden, api only β€” used when gpu quota is exceeded) --
316
+ classify_cpu_img = gr.Image(type='pil', visible=False)
317
+ classify_cpu_out = gr.JSON(visible=False)
318
+ classify_cpu_btn = gr.Button('classify_cpu', visible=False)
319
+ classify_cpu_btn.click(
320
+ fn=classify_cpu_fn,
321
+ inputs=classify_cpu_img,
322
+ outputs=classify_cpu_out,
323
+ api_name='classify_cpu',
324
+ )
325
+
326
  gr.Markdown("""
327
  ---
328
  **API endpoints** (use via [@gradio/client](https://www.gradio.app/guides/getting-started-with-the-js-client)):