Xinqi04 commited on
Commit
2ce1c60
·
1 Parent(s): 681582c
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -20,8 +20,17 @@ class PredictionInput(BaseModel):
20
  @app.post("/predict")
21
  def predict(input_data: PredictionInput):
22
  X = np.array(input_data.data).reshape(1, -1)
23
- prediction = model.predict(X)
24
- return {"prediction": prediction.tolist()}
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  # Tambahkan baris ini jika kamu menjalankan app.py langsung!
 
20
  @app.post("/predict")
21
  def predict(input_data: PredictionInput):
22
  X = np.array(input_data.data).reshape(1, -1)
23
+ prediction = model.predict(X)[0] # ambil nilai scalar
24
+ labels = {
25
+ 0: "Bipolar Type 1",
26
+ 1: "Bipolar Type 2",
27
+ 2: "Depression",
28
+ 3: "Normal"
29
+ }
30
+ return {
31
+ "prediction": int(prediction),
32
+ "label": labels.get(int(prediction), "Unknown")
33
+ }
34
 
35
 
36
  # Tambahkan baris ini jika kamu menjalankan app.py langsung!