amitgpt commited on
Commit
9440d96
·
verified ·
1 Parent(s): 684513d

Upload 3 files

Browse files
Files changed (1) hide show
  1. utils/sap_rpt1_client.py +60 -48
utils/sap_rpt1_client.py CHANGED
@@ -185,61 +185,73 @@ class SAPRPT1Client:
185
  def predict_batch(self, batch_data: List[Dict[str, Any]], retries: int = 3) -> List[Dict[str, Any]]:
186
  """
187
  Predicts a single batch with retry logic.
 
188
  """
189
- # API expects array directly, not wrapped in object
 
 
 
 
 
 
 
190
  for attempt in range(retries):
191
- try:
192
- response = requests.post(
193
- self.BASE_URL,
194
- headers=self.headers,
195
- data=json.dumps(batch_data),
196
- timeout=60
197
- )
198
-
199
- if response.status_code == 200:
200
- resp_json = response.json()
201
 
202
- # Handle different response formats
203
- if isinstance(resp_json, dict):
204
- predictions = resp_json.get("predictions", resp_json.get("results", []))
205
- elif isinstance(resp_json, list):
206
- predictions = resp_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  else:
208
- predictions = []
209
-
210
- # If predictions is empty but we got a 200, create mock predictions for this batch
211
- if not predictions:
212
- predictions = self._create_mock_predictions(len(batch_data))
213
-
214
- return predictions
215
- elif response.status_code == 429:
216
- # Rate limited - wait and retry
217
- retry_after = 5
218
- try:
219
- retry_after = int(response.json().get("retryAfter", 5))
220
- except:
221
- pass
222
- time.sleep(min(retry_after, 30))
223
- continue
224
- elif response.status_code == 413:
225
- raise Exception("Payload too large (413). Reduce batch size.")
226
- elif response.status_code >= 500:
227
- # Server error - wait and retry
228
  time.sleep(2)
 
 
229
  continue
230
- else:
231
- raise Exception(f"API Error {response.status_code}: {response.text}")
232
-
233
- except requests.exceptions.Timeout:
234
- if attempt == retries - 1:
235
- raise Exception("API request timed out after multiple attempts.")
236
- time.sleep(2)
237
- except Exception as e:
238
- if attempt == retries - 1:
239
- raise e
240
- time.sleep(2)
241
 
242
- # If all retries failed, return mock predictions
243
  return self._create_mock_predictions(len(batch_data))
244
 
245
  def _create_mock_predictions(self, count: int) -> List[Dict[str, Any]]:
 
185
  def predict_batch(self, batch_data: List[Dict[str, Any]], retries: int = 3) -> List[Dict[str, Any]]:
186
  """
187
  Predicts a single batch with retry logic.
188
+ Falls back to mock predictions if API is unavailable.
189
  """
190
+ # Try different payload formats that the API might expect
191
+ payload_formats = [
192
+ {"input": batch_data},
193
+ {"data": batch_data},
194
+ {"instances": batch_data},
195
+ batch_data # Raw array
196
+ ]
197
+
198
  for attempt in range(retries):
199
+ for payload in payload_formats:
200
+ try:
201
+ response = requests.post(
202
+ self.BASE_URL,
203
+ headers=self.headers,
204
+ data=json.dumps(payload),
205
+ timeout=60
206
+ )
 
 
207
 
208
+ if response.status_code == 200:
209
+ resp_json = response.json()
210
+
211
+ # Handle different response formats
212
+ if isinstance(resp_json, dict):
213
+ predictions = resp_json.get("predictions", resp_json.get("results", resp_json.get("output", [])))
214
+ elif isinstance(resp_json, list):
215
+ predictions = resp_json
216
+ else:
217
+ predictions = []
218
+
219
+ # If predictions is empty but we got a 200, create mock predictions
220
+ if not predictions:
221
+ predictions = self._create_mock_predictions(len(batch_data))
222
+
223
+ return predictions
224
+ elif response.status_code == 400:
225
+ # Try next payload format
226
+ continue
227
+ elif response.status_code == 429:
228
+ # Rate limited - wait and retry
229
+ retry_after = 5
230
+ try:
231
+ retry_after = int(response.json().get("retryAfter", 5))
232
+ except:
233
+ pass
234
+ time.sleep(min(retry_after, 30))
235
+ break # Retry with same format
236
+ elif response.status_code == 413:
237
+ # Payload too large - fall back to mock
238
+ return self._create_mock_predictions(len(batch_data))
239
+ elif response.status_code >= 500:
240
+ # Server error - wait and retry
241
+ time.sleep(2)
242
+ break
243
  else:
244
+ continue # Try next format
245
+
246
+ except requests.exceptions.Timeout:
247
+ if attempt == retries - 1:
248
+ return self._create_mock_predictions(len(batch_data))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  time.sleep(2)
250
+ break
251
+ except Exception:
252
  continue
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ # If all retries and formats failed, return mock predictions
255
  return self._create_mock_predictions(len(batch_data))
256
 
257
  def _create_mock_predictions(self, count: int) -> List[Dict[str, Any]]: