kodetr commited on
Commit
589f061
·
verified ·
1 Parent(s): a983e8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -46
app.py CHANGED
@@ -2,72 +2,139 @@ import tensorflow as tf
2
  import gradio as gr
3
  import numpy as np
4
  from PIL import Image
5
- from fastapi import FastAPI, UploadFile, File
6
  import io
7
-
8
- gpus = tf.config.list_physical_devices('GPU')
9
- if gpus:
10
- tf.config.experimental.set_memory_growth(gpus[0], True)
11
 
12
  IMG_SIZE = 224
13
 
14
- # ===== LOAD MODEL =====
15
- model = tf.keras.models.load_model("model.keras", compile=False)
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  DR_CLASSES = ["No DR","Mild","Moderate","Severe","Proliferative DR"]
18
  DME_CLASSES = ["No DME","Low Risk","High Risk"]
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # ===== PREPROCESS =====
21
  def preprocess(img):
22
  img = img.resize((IMG_SIZE, IMG_SIZE))
 
 
 
23
  arr = np.array(img) / 255.0
24
- return np.expand_dims(arr,0).astype(np.float32)
25
 
26
- # ===== CORE PREDICT =====
27
  def core_predict(img):
28
- x = preprocess(img)
29
- preds = model.predict(x)
30
-
31
- if isinstance(preds, dict):
32
- dr = preds["dr_head"][0]
33
- dme = preds["dme_head"][0]
34
- else:
35
- dr = preds[0][0]
36
- dme = preds[1][0]
37
-
38
- dr = tf.nn.softmax(dr).numpy()
39
- dme = tf.nn.softmax(dme).numpy()
40
-
41
- return {
42
- "dr": {
43
- "label": DR_CLASSES[int(np.argmax(dr))],
44
- "confidence": float(np.max(dr)*100)
45
- },
46
- "dme": {
47
- "label": DME_CLASSES[int(np.argmax(dme))],
48
- "confidence": float(np.max(dme)*100)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  }
50
- }
51
-
52
- # ===== FASTAPI =====
53
- app = FastAPI()
54
-
55
- @app.post("/predict")
56
- async def api_predict(file: UploadFile = File(...)):
57
- image = await file.read()
58
- img = Image.open(io.BytesIO(image)).convert("RGB")
59
- return core_predict(img)
60
 
61
- # ===== GRADIO =====
62
  def gradio_predict(img):
 
 
63
  return core_predict(img)
64
 
 
65
  demo = gr.Interface(
66
  fn=gradio_predict,
67
- inputs=gr.Image(type="pil"),
68
- outputs="json",
69
- title="DR & DME Detection"
 
 
 
 
 
 
70
  )
71
 
72
- # Mount Gradio ke FastAPI
73
- app = gr.mount_gradio_app(app, demo, path="/")
 
 
 
 
 
2
  import gradio as gr
3
  import numpy as np
4
  from PIL import Image
 
5
  import io
6
+ import os
 
 
 
7
 
8
  IMG_SIZE = 224
9
 
10
+ # ===== DETECT GPU/CPU =====
11
+ try:
12
+ gpus = tf.config.list_physical_devices('GPU')
13
+ if gpus:
14
+ # Coba atur memory growth, tapi jangan crash jika gagal
15
+ try:
16
+ tf.config.experimental.set_memory_growth(gpus[0], True)
17
+ except:
18
+ pass
19
+ print("GPU available")
20
+ else:
21
+ print("Using CPU")
22
+ except:
23
+ print("GPU configuration skipped")
24
 
25
  DR_CLASSES = ["No DR","Mild","Moderate","Severe","Proliferative DR"]
26
  DME_CLASSES = ["No DME","Low Risk","High Risk"]
27
 
28
+ # ===== LOAD MODEL (with error handling) =====
29
+ MODEL_PATH = "model.keras"
30
+
31
+ # Cek apakah model file ada
32
+ if not os.path.exists(MODEL_PATH):
33
+ # Fallback untuk demo jika model tidak ada
34
+ print(f"Warning: {MODEL_PATH} not found. Using mock predictions.")
35
+ model = None
36
+ else:
37
+ try:
38
+ # Load dengan opsi yang lebih kompatibel
39
+ model = tf.keras.models.load_model(
40
+ MODEL_PATH,
41
+ compile=False,
42
+ safe_mode=False # Untuk compatibility
43
+ )
44
+ print("Model loaded successfully")
45
+ except Exception as e:
46
+ print(f"Error loading model: {e}")
47
+ model = None
48
+
49
  # ===== PREPROCESS =====
50
  def preprocess(img):
51
  img = img.resize((IMG_SIZE, IMG_SIZE))
52
+ # Pastikan 3 channel
53
+ if img.mode != 'RGB':
54
+ img = img.convert('RGB')
55
  arr = np.array(img) / 255.0
56
+ return np.expand_dims(arr, 0).astype(np.float32)
57
 
58
+ # ===== CORE PREDICT (with fallback) =====
59
  def core_predict(img):
60
+ # Jika model tidak ada, return mock predictions untuk demo
61
+ if model is None:
62
+ return {
63
+ "dr": {
64
+ "label": "No DR",
65
+ "confidence": 85.5,
66
+ "note": "Mock prediction - model not loaded"
67
+ },
68
+ "dme": {
69
+ "label": "No DME",
70
+ "confidence": 90.2,
71
+ "note": "Mock prediction - model not loaded"
72
+ }
73
+ }
74
+
75
+ try:
76
+ x = preprocess(img)
77
+ preds = model.predict(x, verbose=0) # verbose=0 untuk suppress output
78
+
79
+ # Handle different model output formats
80
+ if isinstance(preds, dict):
81
+ dr = preds.get("dr_head", preds.get("DR", preds.get("output_0")))
82
+ dme = preds.get("dme_head", preds.get("DME", preds.get("output_1")))
83
+ elif isinstance(preds, list) and len(preds) >= 2:
84
+ dr = preds[0]
85
+ dme = preds[1]
86
+ else:
87
+ dr = preds[:, :5] if preds.shape[1] >= 5 else preds
88
+ dme = preds[:, 5:] if preds.shape[1] >= 8 else preds
89
+
90
+ # Pastikan shape benar
91
+ dr = dr[0] if len(dr.shape) > 1 else dr
92
+ dme = dme[0] if len(dme.shape) > 1 else dme
93
+
94
+ # Apply softmax
95
+ dr_probs = tf.nn.softmax(dr).numpy()
96
+ dme_probs = tf.nn.softmax(dme).numpy()
97
+
98
+ return {
99
+ "dr": {
100
+ "label": DR_CLASSES[int(np.argmax(dr_probs))],
101
+ "confidence": float(np.max(dr_probs) * 100)
102
+ },
103
+ "dme": {
104
+ "label": DME_CLASSES[int(np.argmax(dme_probs))],
105
+ "confidence": float(np.max(dme_probs) * 100)
106
+ }
107
+ }
108
+
109
+ except Exception as e:
110
+ return {
111
+ "error": str(e),
112
+ "note": "Prediction failed"
113
  }
 
 
 
 
 
 
 
 
 
 
114
 
115
+ # ===== GRADIO INTERFACE =====
116
  def gradio_predict(img):
117
+ if img is None:
118
+ return {"error": "No image provided"}
119
  return core_predict(img)
120
 
121
+ # Buat Gradio interface dengan tema yang lebih sederhana
122
  demo = gr.Interface(
123
  fn=gradio_predict,
124
+ inputs=gr.Image(type="pil", label="Upload Retina Image"),
125
+ outputs=gr.JSON(label="Prediction Results"),
126
+ title="Diabetic Retinopathy & DME Detection",
127
+ description="Upload a retina fundus image to detect Diabetic Retinopathy (DR) and Diabetic Macular Edema (DME)",
128
+ examples=[
129
+ ["sample1.jpg"], # Pastikan file contoh ada
130
+ ["sample2.jpg"]
131
+ ] if os.path.exists("sample1.jpg") else None,
132
+ allow_flagging="never"
133
  )
134
 
135
+ # Untuk Hugging Face, cukup export demo
136
+ if __name__ == "__main__":
137
+ demo.launch(debug=True)
138
+ else:
139
+ # Untuk Hugging Face deployment
140
+ demo.launch = lambda *args, **kwargs: demo