Spaces:
Sleeping
Sleeping
Commit
·
64afff6
1
Parent(s):
1461675
Add sample output handling in prediction pipeline and API submission
Browse files
app.py
CHANGED
|
@@ -112,87 +112,65 @@ def run_pred_pipeline(input: PredictionInput):
|
|
| 112 |
print(f"Running the pipeline : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")
|
| 113 |
|
| 114 |
## Hardcoding for testing purposes ##
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
# return data_out
|
| 128 |
-
|
| 129 |
-
print(f"Here is the input dict : {input.dict()}")
|
| 130 |
-
print(f"Running the pipeline : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
# "category_mdlz": "EUCO",
|
| 145 |
-
# "basecode": "GB10002",
|
| 146 |
-
# "scenario": "sc_1",
|
| 147 |
-
# "week_date": "2025-04-28",
|
| 148 |
-
# "level_of_sugar": "STANDARD",
|
| 149 |
-
# "pack_group": "CHOC ADULT SGLS",
|
| 150 |
-
# "product_range": "MILKA",
|
| 151 |
-
# "segment": "CHOC SGLS",
|
| 152 |
-
# "supersegment": "STANDARD CHOCOLATE",
|
| 153 |
-
# "base_number_in_multipack": "SINGLE",
|
| 154 |
-
# "flavour": "CITRUS",
|
| 155 |
-
# "choco": "MILK",
|
| 156 |
-
# "salty": "NO",
|
| 157 |
-
# "weight_per_unit_mdlz": "0.28",
|
| 158 |
-
# "list_price_per_unit_mdlz": "1.75"
|
| 159 |
-
# }
|
| 160 |
-
}
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
response_json = response.json()
|
| 166 |
-
print(f"\nPrediction pipeline started with details : {response_json}\n")
|
| 167 |
-
run_id = response_json["run_id"]
|
| 168 |
-
#pred_out = pd.DataFrame()
|
| 169 |
-
while True:
|
| 170 |
-
time.sleep(2)
|
| 171 |
-
api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get?run_id={run_id}"
|
| 172 |
-
response = requests.get(api_url, headers=headers)
|
| 173 |
response_json = response.json()
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get-output"
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
}
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
return data_out
|
| 196 |
|
| 197 |
|
| 198 |
@app.get("/get_prediction_from_databricks")
|
|
|
|
| 112 |
print(f"Running the pipeline : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")
|
| 113 |
|
| 114 |
## Hardcoding for testing purposes ##
|
| 115 |
+
if input.dict().get('sampleOutput') == 'true':
|
| 116 |
+
temp_predictions_dict = generate_random_predictions()
|
| 117 |
+
sample_sim_attr = get_sample_similarity_attr()
|
| 118 |
+
data_out = {
|
| 119 |
+
"status" : "success",
|
| 120 |
+
"data" : {
|
| 121 |
+
"id": input.dict()['id'],
|
| 122 |
+
"predictions": temp_predictions_dict,
|
| 123 |
+
"similarity": sample_sim_attr
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
return data_out
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
+
else:
|
| 129 |
+
|
| 130 |
+
headers = {
|
| 131 |
+
"Authorization": f"Bearer {API_TOKEN}",
|
| 132 |
+
"Content-Type": "application/json"
|
| 133 |
+
}
|
| 134 |
+
# Pipeline details
|
| 135 |
+
pipeline_id = "403360183892362"
|
| 136 |
+
payload = {
|
| 137 |
+
'job_id': pipeline_id,
|
| 138 |
+
'notebook_params': input.dict()
|
| 139 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
# Trigger the run
|
| 142 |
+
api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/run-now"
|
| 143 |
+
response = requests.post(api_url, headers=headers, data=json.dumps(payload))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
response_json = response.json()
|
| 145 |
+
print(f"\nPrediction pipeline started with details : {response_json}\n")
|
| 146 |
+
run_id = response_json["run_id"]
|
| 147 |
+
#pred_out = pd.DataFrame()
|
| 148 |
+
while True:
|
| 149 |
+
time.sleep(2)
|
| 150 |
+
api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get?run_id={run_id}"
|
| 151 |
+
response = requests.get(api_url, headers=headers)
|
| 152 |
+
response_json = response.json()
|
| 153 |
+
task_run_id = response_json['tasks'][0]['run_id']
|
| 154 |
+
run_status = response_json["state"]["life_cycle_state"]
|
| 155 |
+
print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Status : {run_status}")
|
| 156 |
+
job_status = response_json["state"].get('result_state')
|
| 157 |
+
if job_status == 'SUCCESS':
|
| 158 |
+
api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get-output"
|
| 159 |
+
payload = dict(run_id=task_run_id)
|
| 160 |
+
response = requests.get(api_url, headers=headers, data=json.dumps(payload))
|
| 161 |
+
output_json = json.loads(response.json()['notebook_output']['result'])
|
| 162 |
+
temp_predictions_dict, sample_sim_attr = process_api_response(output_json)
|
| 163 |
+
data_out = {
|
| 164 |
+
"status" : "success",
|
| 165 |
+
"data" : {
|
| 166 |
+
"id": input.dict()['id'],
|
| 167 |
+
"predictions": temp_predictions_dict,
|
| 168 |
+
"similarity": sample_sim_attr
|
| 169 |
+
}
|
| 170 |
}
|
| 171 |
+
break;
|
| 172 |
+
|
| 173 |
+
return data_out
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
@app.get("/get_prediction_from_databricks")
|