Vaibhav Gaikwad commited on
Commit
5de371d
Β·
1 Parent(s): 0e8ef8a

Revert "gpu to cpu fallback when limit quota exceeds"

Browse files

This reverts commit 0e8ef8a6fd94c23c977a6412ed027f70d54bfb4b.

Files changed (1) hide show
  1. app.py +15 -64
app.py CHANGED
@@ -3,20 +3,15 @@ audiolens β€” app.py
3
  huggingface space backend (zerogpu + gradio native api)
4
 
5
  api endpoints (via gradio):
6
- /call/classify β€” document type classification (dit-base, gpu)
7
- /call/classify_cpu β€” same but cpu-only fallback (no gpu quota needed)
8
- /call/ocr β€” text extraction (easyocr, cpu)
9
- /call/speak β€” text to speech (kokoro, gpu)
10
- /call/health β€” check if space is warm
11
 
12
  the pwa calls these using the gradio js client (@gradio/client)
13
  or via gradio's rest api. each function decorated with @spaces.GPU
14
  gets a gpu allocation only for the duration of that call.
15
 
16
- when gpu quota is exceeded, the pwa falls back to:
17
- - /call/classify_cpu for classification (slower but works)
18
- - browser Web Speech API for tts (no server needed)
19
-
20
  llm extraction (gemini) is called directly from the pwa β€” not here.
21
  """
22
 
@@ -35,17 +30,6 @@ import gradio as gr
35
  from j2_preprocess import preprocess
36
 
37
 
38
- def get_device():
39
- """picks the best available device at call time.
40
- on hf, cuda is only available inside @spaces.GPU functions.
41
- on mac, mps is always available. falls back to cpu."""
42
- if torch.cuda.is_available():
43
- return 'cuda'
44
- if torch.backends.mps.is_available():
45
- return 'mps'
46
- return 'cpu'
47
-
48
-
49
  # ============================================================
50
  # -- dit class mapping --
51
  # ============================================================
@@ -120,9 +104,8 @@ def classify_fn(image):
120
  return {'error': 'no image provided'}
121
 
122
  try:
123
- device = get_device()
124
- dit_model.to(device)
125
- inputs = dit_processor(images=image, return_tensors='pt').to(device)
126
 
127
  with torch.no_grad():
128
  logits = dit_model(**inputs).logits
@@ -139,34 +122,6 @@ def classify_fn(image):
139
  return {'error': str(e)}
140
 
141
 
142
- def classify_cpu_fn(image):
143
- """
144
- cpu-only fallback for classification.
145
- called when gpu quota is exceeded.
146
- same logic as classify_fn but runs entirely on cpu β€” slower but no quota.
147
- called via gradio api: /call/classify_cpu
148
- """
149
- if image is None:
150
- return {'error': 'no image provided'}
151
-
152
- try:
153
- dit_model.to('cpu')
154
- inputs = dit_processor(images=image, return_tensors='pt').to('cpu')
155
-
156
- with torch.no_grad():
157
- logits = dit_model(**inputs).logits
158
-
159
- selected_logits = logits[0, SELECTED_RVL_IDX]
160
- pred_idx = selected_logits.argmax().item()
161
- confidence = torch.softmax(selected_logits, dim=0)[pred_idx].item()
162
- doc_type = DIT_CLASS_MAP[SELECTED_RVL_IDX[pred_idx]]
163
-
164
- return {'doc_type': doc_type, 'confidence': round(confidence, 4)}
165
-
166
- except Exception as e:
167
- return {'error': str(e)}
168
-
169
-
170
  def ocr_gpu(clean_image):
171
  """
172
  runs easyocr on a preprocessed image.
@@ -198,14 +153,21 @@ def ocr_fn(image):
198
  return 'error: no image provided'
199
 
200
  try:
201
- # convert pil to cv2
202
  cv2_image = pil_to_cv2(image)
203
 
204
- # preprocessing β€” easyocr handles its own internally test
 
 
 
 
 
 
205
  # clean = preprocess(cv2_image)
206
 
207
  # ocr inference on cpu
208
  text = ocr_gpu(cv2_image)
 
209
  return text
210
 
211
  except Exception as e:
@@ -312,17 +274,6 @@ with gr.Blocks(title='AudioLens API') as demo:
312
  api_name='health',
313
  )
314
 
315
- # -- cpu fallbacks (hidden, api only β€” used when gpu quota is exceeded) --
316
- classify_cpu_img = gr.Image(type='pil', visible=False)
317
- classify_cpu_out = gr.JSON(visible=False)
318
- classify_cpu_btn = gr.Button('classify_cpu', visible=False)
319
- classify_cpu_btn.click(
320
- fn=classify_cpu_fn,
321
- inputs=classify_cpu_img,
322
- outputs=classify_cpu_out,
323
- api_name='classify_cpu',
324
- )
325
-
326
  gr.Markdown("""
327
  ---
328
  **API endpoints** (use via [@gradio/client](https://www.gradio.app/guides/getting-started-with-the-js-client)):
 
3
  huggingface space backend (zerogpu + gradio native api)
4
 
5
  api endpoints (via gradio):
6
+ /call/classify β€” document type classification (dit-base)
7
+ /call/ocr β€” text extraction (easyocr)
8
+ /call/speak β€” text to speech (kokoro)
9
+ /call/health β€” check if space is warm
 
10
 
11
  the pwa calls these using the gradio js client (@gradio/client)
12
  or via gradio's rest api. each function decorated with @spaces.GPU
13
  gets a gpu allocation only for the duration of that call.
14
 
 
 
 
 
15
  llm extraction (gemini) is called directly from the pwa β€” not here.
16
  """
17
 
 
30
  from j2_preprocess import preprocess
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
33
  # ============================================================
34
  # -- dit class mapping --
35
  # ============================================================
 
104
  return {'error': 'no image provided'}
105
 
106
  try:
107
+ dit_model.to('cuda')
108
+ inputs = dit_processor(images=image, return_tensors='pt').to('cuda')
 
109
 
110
  with torch.no_grad():
111
  logits = dit_model(**inputs).logits
 
122
  return {'error': str(e)}
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def ocr_gpu(clean_image):
126
  """
127
  runs easyocr on a preprocessed image.
 
153
  return 'error: no image provided'
154
 
155
  try:
156
+ # convert pil to cv2 for preprocessing
157
  cv2_image = pil_to_cv2(image)
158
 
159
+ # # preprocessing runs on cpu β€” outside the gpu function
160
+ # clean = preprocess(cv2_image)
161
+
162
+ # # ocr inference on cpu
163
+ # text = ocr_gpu(clean)
164
+
165
+ # trusting easyOCR for test preprocess
166
  # clean = preprocess(cv2_image)
167
 
168
  # ocr inference on cpu
169
  text = ocr_gpu(cv2_image)
170
+
171
  return text
172
 
173
  except Exception as e:
 
274
  api_name='health',
275
  )
276
 
 
 
 
 
 
 
 
 
 
 
 
277
  gr.Markdown("""
278
  ---
279
  **API endpoints** (use via [@gradio/client](https://www.gradio.app/guides/getting-started-with-the-js-client)):