Spaces:
Sleeping
Sleeping
Set threshold 0.4, multilabel textbox hidden
Browse files
app.py
CHANGED
|
@@ -144,6 +144,9 @@ def get_img_array(img_path):
|
|
| 144 |
|
| 145 |
|
| 146 |
def get_prediction(img_path):
|
|
|
|
|
|
|
|
|
|
| 147 |
# check the image path
|
| 148 |
print(f"Image path: {img_path}")
|
| 149 |
# also display the original filename for info
|
|
@@ -156,7 +159,7 @@ def get_prediction(img_path):
|
|
| 156 |
# binary label
|
| 157 |
pred_binary = keras_binary_model(img_array, training=False)
|
| 158 |
print(f"Keras binary label: {pred_binary}")
|
| 159 |
-
|
| 160 |
fake = "Fake"
|
| 161 |
else:
|
| 162 |
fake = "Real"
|
|
@@ -165,7 +168,7 @@ def get_prediction(img_path):
|
|
| 165 |
pred_multi = keras_multi_model(img_array, training=False)
|
| 166 |
print(f"Keras multi label: {pred_multi}")
|
| 167 |
# Cut at the sigmoid 0.5 threshold
|
| 168 |
-
fake_parts = np.where(pred_multi >
|
| 169 |
print(f"Multi label: {fake_parts}")
|
| 170 |
# Format each of the fake face parts
|
| 171 |
parts_message = dict()
|
|
@@ -255,7 +258,8 @@ with gr.Blocks() as demo:
|
|
| 255 |
interactive=False, lines=2)
|
| 256 |
text_3 = gr.Text(
|
| 257 |
label="Multi label, Efficient net v2 B0",
|
| 258 |
-
interactive=False, lines=7
|
|
|
|
| 259 |
"""
|
| 260 |
text_3 = gr.Text(label="Sashi's model",
|
| 261 |
interactive=False, lines=3)
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
def get_prediction(img_path):
|
| 147 |
+
# adjust threshold for accuracy
|
| 148 |
+
threshold = 0.4
|
| 149 |
+
|
| 150 |
# check the image path
|
| 151 |
print(f"Image path: {img_path}")
|
| 152 |
# also display the original filename for info
|
|
|
|
| 159 |
# binary label
|
| 160 |
pred_binary = keras_binary_model(img_array, training=False)
|
| 161 |
print(f"Keras binary label: {pred_binary}")
|
| 162 |
+
if pred_binary[0][0] > threshold:
|
| 163 |
fake = "Fake"
|
| 164 |
else:
|
| 165 |
fake = "Real"
|
|
|
|
| 168 |
pred_multi = keras_multi_model(img_array, training=False)
|
| 169 |
print(f"Keras multi label: {pred_multi}")
|
| 170 |
# Cut at the sigmoid 0.5 threshold
|
| 171 |
+
fake_parts = np.where(pred_multi > threshold, 1, 0)
|
| 172 |
print(f"Multi label: {fake_parts}")
|
| 173 |
# Format each of the fake face parts
|
| 174 |
parts_message = dict()
|
|
|
|
| 258 |
interactive=False, lines=2)
|
| 259 |
text_3 = gr.Text(
|
| 260 |
label="Multi label, Efficient net v2 B0",
|
| 261 |
+
interactive=False, lines=7,
|
| 262 |
+
visible=False)
|
| 263 |
"""
|
| 264 |
text_3 = gr.Text(label="Sashi's model",
|
| 265 |
interactive=False, lines=3)
|