Dinaliah Xaviant commited on
Commit
e07fa3a
·
verified ·
1 Parent(s): 52718ad

Update app.py (#1)

Browse files

- Update app.py (2ef87e7f2d4f3cc07fabc8e6ed293644140e8eeb)


Co-authored-by: Daffa Dians Ramadhan <Xaviant@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +97 -91
app.py CHANGED
@@ -4,103 +4,109 @@ import numpy as np
4
  from PIL import Image
5
  import io
6
  import sys
7
- import gradio as gr
8
- from tensorflow.keras.applications.resnet50 import preprocess_input
9
 
10
- # =========================
11
- # 1. FastAPI Init
12
- # =========================
13
- app = FastAPI(title="Ashoka Buried Penis Classifier API")
14
 
15
- # =========================
16
  # 2. Load Model
17
- # =========================
18
- print("Loading model...")
19
  try:
20
- model = tf.keras.models.load_model("cnn_kfold_best_model.h5")
21
- print("Model loaded successfully")
22
  except Exception as e:
23
- print("Failed to load model:", e)
24
- sys.exit(1)
25
-
26
- class_names = ["Normal", "Buried"]
27
-
28
- # =========================
29
- # 3. Preprocessing
30
- # =========================
31
- def prepare_image(image: Image.Image):
32
- image = image.convert("RGB")
33
- image = image.resize((224, 224))
34
- img_array = np.array(image)
35
- img_array = np.expand_dims(img_array, axis=0)
36
- img_array = preprocess_input(img_array)
37
- return img_array
38
-
39
- # =========================
40
- # 4. Prediction Logic
41
- # =========================
42
- def predict_image(image):
43
- if image is None:
44
- return "No image uploaded", 0.0, 0.0
45
-
46
- processed = prepare_image(image)
47
- prediction = model.predict(processed)[0][0]
48
-
49
- prob_buried = float(prediction * 100)
50
- prob_normal = float((1 - prediction) * 100)
51
-
52
- label = "Buried Penis" if prediction > 0.5 else "Normal"
53
-
54
- return label, round(prob_normal, 2), round(prob_buried, 2)
55
-
56
- # =========================
57
- # 5. FastAPI Endpoint
58
- # =========================
 
 
59
  @app.post("/predict")
60
- async def api_predict(file: UploadFile = File(...)):
61
- image_bytes = await file.read()
62
- image = Image.open(io.BytesIO(image_bytes))
63
- label, normal, buried = predict_image(image)
64
-
65
- return {
66
- "class": label,
67
- "probabilities": {
68
- "normal": normal,
69
- "buried": buried
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  }
72
 
73
- # =========================
74
- # 6. Gradio UI
75
- # =========================
76
- with gr.Blocks() as demo:
77
- gr.Markdown("""
78
- # Ashoka Hipospadia Classifier API - ResNet50
79
- **Medical screening tool for Buried Penis**
80
-
81
- ⚠️ This tool is **NOT a diagnostic device**.
82
- Results must be interpreted by **qualified medical professionals**.
83
- """)
84
-
85
- with gr.Row():
86
- image_input = gr.Image(
87
- type="pil",
88
- label="Upload / Drag & Drop Medical Image"
89
- )
90
-
91
- classify_btn = gr.Button("Analyze Image")
92
-
93
- result_label = gr.Textbox(label="Prediction Result")
94
- prob_normal = gr.Number(label="Normal Probability (%)")
95
- prob_buried = gr.Number(label="Buried Probability (%)")
96
-
97
- classify_btn.click(
98
- fn=predict_image,
99
- inputs=image_input,
100
- outputs=[result_label, prob_normal, prob_buried]
101
- )
102
-
103
- # =========================
104
- # 7. Mount Gradio to FastAPI
105
- # =========================
106
- app = gr.mount_gradio_app(app, demo, path="/")
 
4
  from PIL import Image
5
  import io
6
  import sys
7
+ from tensorflow.keras.applications.densenet import preprocess_input
 
8
 
9
+ # 1. Inisialisasi Aplikasi
10
+ app = FastAPI(title="Ashoka Hipospadia Classifier API")
 
 
11
 
 
12
  # 2. Load Model
13
+ print("Sedang memuat model...")
 
14
  try:
15
+ model = tf.keras.models.load_model('cnn_kfold_best_model_v2.h5')
16
+ print("Model berhasil dimuat!")
17
  except Exception as e:
18
+ print(f"Error memuat model: {e}")
19
+ sys.exit(1) # Matikan server jika model gagal load
20
+
21
+ # Label kelas: 0 = normal, 1 = buried
22
+ class_names = ['normal', 'buried']
23
+
24
+ # 3. Fungsi Preprocessing
25
+ def prepare_image(image_bytes):
26
+ """
27
+ Preprocessing gambar untuk model DenseNet
28
+ - Konversi ke RGB (3 channel)
29
+ - Resize ke 224x224
30
+ - Preprocessing DenseNet
31
+ """
32
+ try:
33
+ img = Image.open(io.BytesIO(image_bytes))
34
+
35
+ # Paksa ubah ke RGB agar PNG transparan tidak error
36
+ img = img.convert("RGB")
37
+
38
+ # Resize ke ukuran input model (224x224 untuk ResNet50)
39
+ img = img.resize((224, 224))
40
+
41
+ # Convert ke numpy array
42
+ img_array = np.array(img)
43
+
44
+ # Tambah batch dimension
45
+ img_array = np.expand_dims(img_array, axis=0)
46
+
47
+ # Preprocessing ResNet50 (HARUS sama dengan training!)
48
+ img_array = preprocess_input(img_array)
49
+
50
+ return img_array
51
+ except Exception as e:
52
+ print(f"Error saat memproses gambar: {e}")
53
+ return None
54
+
55
+ # 4. Endpoint Prediksi
56
  @app.post("/predict")
57
+ async def predict(file: UploadFile = File(...)):
58
+ """
59
+ Endpoint untuk prediksi gambar
60
+ Input: File gambar (JPG, PNG, BMP)
61
+ Output: JSON dengan class dan confidence
62
+ """
63
+ try:
64
+ # Baca file gambar
65
+ image_bytes = await file.read()
66
+
67
+ # Proses gambar
68
+ processed_image = prepare_image(image_bytes)
69
+
70
+ if processed_image is None:
71
+ raise HTTPException(status_code=400, detail="File bukan gambar yang valid")
72
+
73
+ # Prediksi
74
+ prediction = model.predict(processed_image)
75
+ pred_value = float(prediction[0][0])
76
+
77
+ # Hitung probabilitas
78
+ # Model output: 0 = normal, 1 = buried
79
+ prob_normal = (1 - pred_value) * 100
80
+ prob_buried = pred_value * 100
81
+
82
+ # Tentukan kelas berdasarkan threshold 0.5
83
+ top_class_idx = 1 if pred_value > 0.5 else 0
84
+
85
+ # Hasil dalam format JSON
86
+ result = {
87
+ "class": class_names[top_class_idx],
88
+ "confidence": float(max(prob_normal, prob_buried)),
89
+ "probabilities": {
90
+ "normal": float(prob_normal),
91
+ "buried": float(prob_buried)
92
+ }
93
  }
94
+ return result
95
+
96
+ except Exception as e:
97
+ # Cetak error ke log
98
+ print(f"CRITICAL ERROR: {e}")
99
+ raise HTTPException(status_code=500, detail=str(e))
100
+
101
+ # 5. Endpoint Home
102
+ @app.get("/")
103
+ def home():
104
+ """Endpoint root untuk testing API"""
105
+ return {
106
+ "message": "Ashoka Hipospadia Classifier API Online! 🚀\n",
107
+ "model": "DenseNet Binary Classification\n",
108
+ "classes": class_names
109
  }
110
 
111
+ # API siap digunakan dengan uvicorn
112
+ # Jalankan dengan: uvicorn app:app --host 0.0.0.0 --port 7860