File size: 8,735 Bytes
db8243f
 
 
 
 
 
d574c84
b2e8e2a
db8243f
 
1d7e89f
9a6bc8b
27cedc0
db8243f
2bfc17a
6e07918
db8243f
 
 
 
 
e710ae2
db8243f
 
 
 
 
82b91c3
 
 
 
 
 
 
db8243f
9a6bc8b
8ce5b5b
 
9a6bc8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d574c84
9a6bc8b
27cedc0
 
 
db8243f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27cedc0
 
1228c06
 
 
 
27cedc0
1228c06
 
 
 
27cedc0
1228c06
27cedc0
 
 
1228c06
 
79d0f4a
9a6bc8b
26ad899
 
 
 
 
d574c84
 
64afff6
 
 
 
 
 
 
 
 
 
 
4390bf9
64afff6
 
 
 
 
 
 
 
 
 
 
 
4390bf9
64afff6
 
 
4390bf9
64afff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b41ff88
64afff6
 
 
4390bf9
 
db8243f
 
 
82b91c3
db8243f
82b91c3
db8243f
 
 
82b91c3
 
db8243f
 
 
 
 
 
 
 
82b91c3
db8243f
 
 
 
 
 
 
82b91c3
db8243f
 
 
 
 
 
 
82b91c3
db8243f
 
 
 
 
 
 
 
e710ae2
 
 
 
 
 
 
 
 
 
 
 
9a6bc8b
6e07918
9a6bc8b
1d7e89f
6a572de
 
 
9a6bc8b
 
 
 
 
 
 
e710ae2
 
 
 
 
 
 
 
db8243f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
from typing import Union
import os
import requests
import json
import time
from datetime import datetime
import time
import pandas as pd
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi import Query
from transformers import pipeline

from helper import generate_random_predictions, get_sample_similarity_attr, process_api_response

# FastAPI application instance; endpoints are registered on it below.
app = FastAPI()

# Configure CORS settings — wide open for development only.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# contradictory under the CORS spec (browsers reject credentialed requests
# to a wildcard origin) — tighten allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins in development
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)

# Databricks workspace connection settings, supplied via environment variables.
DATABRICKS_INSTANCE = os.getenv('DATABRICKS_INSTANCE')
API_TOKEN = os.getenv('API_TOKEN')
# Run id of a completed job run whose stored output /get_prediction serves.
TASK_RUNID = "1054089068841244"

# For local development these can be loaded from a .env file:
# from dotenv import load_dotenv, find_dotenv
# _ = load_dotenv(find_dotenv()) # read local .env file

class PredictionInput(BaseModel):
    """Request body for /get_prediction_on_userinput.

    Field names are camelCase to match the frontend payload and are passed
    verbatim to the Databricks job as ``notebook_params`` — do not rename.
    """
    id:str  # client-side correlation id, echoed back in the response
    isMinimized: bool
    country: str
    category: str
    basecode: str
    scenario: str
    weekDate: str
    packGroup: str
    productRange: str
    baseNumberInMultipack: str
    segment: str
    superSegment: str
    salty: str
    choco: str
    flavor: str
    levelOfSugar: str
    listPricePerUnitMl: float
    weightPerUnitMl: float
    sampleOutput: bool  # True → return canned sample data instead of running the job

class inputtext(BaseModel):
    """Single-field text request schema.

    NOTE(review): appears unused — /get_sentiment uses ``InputText`` defined
    below. Kept (with its non-PascalCase name) in case external code imports
    it; confirm before removing.
    """
    inputtext:str  # raw text payload

@app.get("/")
def read_root():
    """Root endpoint; serves as a trivial liveness check."""
    greeting = {"Hello": "World"}
    return greeting

@app.get("/get_prediction")
def get_prediction_from_jobrun():
    """Fetch the stored notebook output of a fixed Databricks job run.

    Calls the Jobs 2.1 ``runs/get-output`` endpoint for ``TASK_RUNID`` and
    returns the ``prediction`` field parsed from the notebook's JSON result.
    On a non-200 response, logs the failure and returns the raw response
    text so the caller can see the Databricks error.
    """
    url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get-output"
    headers = {
        'Authorization': f'Bearer {API_TOKEN}',
        'Content-Type': 'application/json'
    }

    # run_id goes in the query string: the Jobs 2.1 GET endpoints take query
    # parameters, and GET requests with a message body are unreliable.
    # timeout prevents the endpoint from hanging indefinitely on Databricks.
    response = requests.get(url, headers=headers,
                            params={"run_id": TASK_RUNID}, timeout=60)

    if response.status_code == 200:
        output_json = json.loads(response.json()['notebook_output']['result'])
        return output_json['prediction']
    else:
        print("Failed to fetch job run output.")
        print("Status Code:", response.status_code)
        return response.text

# Module-level singleton so the model loads once at startup, not per request.
# NOTE(review): first run downloads the model — startup blocks on network.
classifier = pipeline("sentiment-analysis")  # Defaults to distilbert-base-uncased-finetuned-sst-2-english

# Define input schema
class InputText(BaseModel):
    """Request body for /get_sentiment."""
    text: str  # Expect JSON request body with a "text" field

@app.post("/get_sentiment")
def get_sentiment_details(input: InputText):
    """Classify the sentiment of the posted text.

    Returns the classifier's top label and its confidence score.
    """
    text = input.text  # pull the plain string out of the request model

    print(f"===== The type of the text is : {type(text)} =====")  # Debugging output

    # The pipeline returns a list of {'label': ..., 'score': ...} dicts.
    top = classifier(text)[0]
    return {"sentiment": top['label'], "score": top['score']}
    
@app.post("/get_prediction_on_userinput")
def run_pred_pipeline(input: PredictionInput):
    """Run the Databricks prediction job for a user-supplied product config.

    When ``sampleOutput`` is true, short-circuits and returns canned random
    predictions after a short delay (frontend testing path). Otherwise it
    triggers the Databricks job with the input as notebook params, polls
    until the run reaches a terminal state, and returns the notebook output
    reshaped by ``process_api_response``. A failed/cancelled run now returns
    an error payload instead of polling forever (original bug).
    """
    input_dict = input.dict()  # cache once; original re-serialized the model repeatedly

    print(f"Here is the input dict : {input_dict}")
    print(f"Running the pipeline : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")

    ## Hardcoding for testing purposes ##
    if input_dict.get('sampleOutput'):
        time.sleep(4)  # simulate job latency so the UI's loading state is visible
        temp_predictions_dict = generate_random_predictions()
        sample_sim_attr = get_sample_similarity_attr()
        return {
            "status": "success",
            "data": {
                "id": input_dict['id'],
                "predictions": temp_predictions_dict,
                "similarity": sample_sim_attr
            }
        }

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json"
    }
    # Pipeline details
    pipeline_id = "403360183892362"
    payload = {
        'job_id': pipeline_id,
        'notebook_params': input_dict
    }

    # Trigger the run (timeout prevents an indefinite hang on Databricks).
    api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/run-now"
    response = requests.post(api_url, headers=headers, data=json.dumps(payload), timeout=60)
    response_json = response.json()
    print(f"\nPrediction pipeline started with details : {response_json}\n")
    run_id = response_json["run_id"]

    # Poll until the run reaches a terminal state. result_state is only
    # present once the run has finished.
    while True:
        time.sleep(2)
        api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get?run_id={run_id}"
        response = requests.get(api_url, headers=headers, timeout=60)
        response_json = response.json()
        task_run_id = response_json['tasks'][0]['run_id']
        run_status = response_json["state"]["life_cycle_state"]
        print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Status : {run_status}")
        job_status = response_json["state"].get('result_state')
        if job_status == 'SUCCESS':
            # Fetch the finished task's notebook output; run_id belongs in
            # the query string for GET endpoints of the Jobs 2.1 API.
            api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get-output"
            response = requests.get(api_url, headers=headers,
                                    params={"run_id": task_run_id}, timeout=60)
            output_json = json.loads(response.json()['notebook_output']['result'])
            temp_predictions_dict, sample_sim_attr = process_api_response(output_json)
            return {
                "status": "success",
                "data": {
                    "id": input_dict['id'],
                    "predictions": temp_predictions_dict,
                    "similarity": sample_sim_attr
                }
            }
        if job_status is not None:  # FAILED / CANCELED / TIMEDOUT — stop polling
            return {
                "status": "error",
                "error": f"Databricks run {run_id} finished with result_state={job_status}"
            }


@app.get("/get_prediction_from_databricks")
def run_xpipeline():
    """Trigger the fixed Databricks prediction job and return its output.

    Starts job 413640122908266 with empty notebook params via jobs/run-now,
    polls runs/get every two seconds until the run reaches a terminal state,
    then returns the ``prediction`` field of the notebook's JSON output.
    A failed/cancelled run now returns an error payload instead of polling
    forever (original bug).
    """
    print(f"Running the pipeline : {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ")

    headers = {
        "Authorization": f"Bearer {API_TOKEN}",
        "Content-Type": "application/json"
    }
    # Pipeline details
    pipeline_id = "413640122908266"
    json_data = None  # this job takes no input data
    payload = {
        'job_id': pipeline_id,
        'notebook_params': {
            'data': json_data  # Send data as a JSON string
        }
    }

    # Trigger the run (timeout prevents an indefinite hang on Databricks).
    api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/run-now"
    response = requests.post(api_url, headers=headers, data=json.dumps(payload), timeout=60)
    response_json = response.json()
    print(f"\nPrediction pipeline started with details : {response_json}\n")
    run_id = response_json["run_id"]

    # Poll until the run reaches a terminal state; result_state appears
    # only once the run has finished.
    while True:
        time.sleep(2)
        api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get?run_id={run_id}"
        response = requests.get(api_url, headers=headers, timeout=60)
        response_json = response.json()
        task_run_id = response_json['tasks'][0]['run_id']
        run_status = response_json["state"]["life_cycle_state"]
        print(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} Status : {run_status}")
        job_status = response_json["state"].get('result_state')
        if job_status == 'SUCCESS':
            # Fetch the finished task's notebook output; run_id belongs in
            # the query string for GET endpoints of the Jobs 2.1 API.
            api_url = f"{DATABRICKS_INSTANCE}/api/2.1/jobs/runs/get-output"
            response = requests.get(api_url, headers=headers,
                                    params={"run_id": task_run_id}, timeout=60)
            output_json = json.loads(response.json()['notebook_output']['result'])
            return output_json['prediction']
        if job_status is not None:  # FAILED / CANCELED / TIMEDOUT — stop polling
            return {
                "status": "error",
                "error": f"Databricks run {run_id} finished with result_state={job_status}"
            }


class QueryRequest(BaseModel):
    """Request body for /query_ai."""
    query: str  # free-text user query (currently unused by the placeholder endpoint)

@app.post("/query_ai")
async def query_ai(request: QueryRequest):
    """Placeholder AI query endpoint.

    TODO: Implement actual AI processing. Until then it ignores the query
    and returns a fixed sample product so the frontend flow can be tested.
    """
    try:
        sample_product = {
            "baseCode": "GB10002",
            "scenario": "SAMPLE_EUCO_Scenario",
            "weekDate": "2025-04-28",
            "levelOfSugar": "STANDARD",
            "packGroup": "EVERYDAY BLOCK",
            "productRange": "GREEN & BLACKS",
            "segment": "CHOC BLOCK",
            "superSegment": "STANDARD CHOCOLATE",
            "baseNumberInMultipack": "SINGLE",
            "flavor": "CITRUS",
            "choco": "MILK",
            "salty": "NO",
            "weightPerUnitMl": 0.28,
            "listPricePerUnitMl": 1.75,
        }
        return {"status": "success", "data": sample_product}
    except Exception as e:
        return {"status": "error", "error": str(e)}

@app.get("/items/{item_id}")
def read_item(item_id: int, q: Union[str, None] = None):
    """Echo back the path item id together with the optional query string q."""
    result = {"item_id": item_id, "q": q}
    return result