# Author: TaliZG03 — commit fbfca97 ("Update app.py", verified)
import os
import shutil
import traceback
import zipfile

# Must be set before TensorFlow is imported (directly or via Keras) to reduce
# TensorFlow's C++ log noise.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf  # noqa: E402  (must follow the env var above)
import keras  # noqa: E402

import gradio as gr  # noqa: E402
import h5py  # noqa: E402
import numpy as np  # noqa: E402
from huggingface_hub import hf_hub_download  # noqa: E402
from PIL import Image  # noqa: E402
# Log the framework versions first; presumably useful when diagnosing why the
# saved .keras file does not load directly on this runtime.
for _lib_label, _lib_module in (("TF", tf), ("Keras", keras)):
    print(f"{_lib_label}:", _lib_module.__version__, flush=True)

REPO_ID = "TaliZG03/kidney_normal_CT_classifier_model"
MODEL_FILENAME = "model.keras"

# -------------------------
# 1) Download the .keras archive from the Hub. It cannot be loaded as a full
#    model, so only the weights file inside it is used below.
# -------------------------
model_zip = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
print("Downloaded model.keras:", model_zip, flush=True)
# -------------------------
# 2) Extract model.weights.h5 from the .keras archive (a .keras file is a zip).
# -------------------------
extract_dir = "/tmp/extracted"

# Start from an empty directory so nothing from a previous run is reused.
shutil.rmtree(extract_dir, ignore_errors=True)
os.makedirs(extract_dir, exist_ok=True)

weights_path = os.path.join(extract_dir, "model.weights.h5")

with zipfile.ZipFile(model_zip, mode="r") as keras_archive:
    print("Archive contents:", keras_archive.namelist(), flush=True)
    keras_archive.extract("model.weights.h5", path=extract_dir)

print("Extracted weights:", weights_path, flush=True)
# -------------------------
# 3) Inspect the weights file to help identify the architecture
#    (prints the top-level H5 groups and a sample of dataset paths).
# -------------------------
def inspect_h5(h5_path: str, max_root=120, max_datasets=60):
    """Print a summary of an HDF5 file's layout.

    Args:
        h5_path: Path to the HDF5 file to open read-only.
        max_root: Maximum number of top-level keys to print.
        max_datasets: Maximum number of dataset paths to print.
    """
    print("\n=== H5 INSPECTION ===", flush=True)
    with h5py.File(h5_path, "r") as h5_file:
        top_level_keys = list(h5_file.keys())
        print("H5 root keys count:", len(top_level_keys), flush=True)
        print("H5 root keys (first):", top_level_keys[:max_root], flush=True)

        dataset_paths = []

        def _collect_dataset(path, node):
            # Return nothing: h5py's visititems() stops walking on a non-None return.
            if isinstance(node, h5py.Dataset):
                dataset_paths.append(path)

        h5_file.visititems(_collect_dataset)

        print("\nDataset count:", len(dataset_paths), flush=True)
        print("Dataset names (first):", dataset_paths[:max_datasets], flush=True)
    print("=== END H5 INSPECTION ===\n", flush=True)


inspect_h5(weights_path)
# -------------------------
# 4) Rebuild the architecture (PLACEHOLDER)
# IMPORTANT:
# - It MUST match the original training architecture exactly for the weights
#   to load.
# - This is a best-guess template.
# - Rescaling is used instead of Normalization so there are no mean/var/count
#   variables that the weights file would have to supply.
# -------------------------
def build_model(input_shape=(224, 224, 3), num_classes=1, backbone="EfficientNetB3"):
    """Assemble a Rescaling -> backbone -> GAP -> Dense classifier.

    Args:
        input_shape: Shape of one input image, e.g. (224, 224, 3); a CT
            pipeline might instead use something like (512, 512, 1).
        num_classes: 1 builds a single sigmoid unit (binary); any other value
            builds a softmax layer with that many units.
        backbone: Any name starting with "EfficientNet" that exists in
            ``keras.applications`` (e.g. "EfficientNetB0".."EfficientNetB4"),
            or "MobileNetV2".

    Returns:
        An uncompiled ``keras.Model`` named "<backbone>_classifier".

    Raises:
        ValueError: If ``backbone`` is neither an EfficientNet name nor
            "MobileNetV2".

    NOTE(review): keras.applications EfficientNet models appear to apply their
    own input rescaling internally; confirm that this extra outer Rescaling
    layer matches the original training graph.
    """
    inputs = keras.Input(shape=input_shape, name="input")
    # Rescaling has no saved variables, unlike Normalization.
    scaled = keras.layers.Rescaling(1.0 / 255.0, name="rescaling")(inputs)

    if backbone.startswith("EfficientNet"):
        backbone_factory = getattr(keras.applications, backbone)
    elif backbone == "MobileNetV2":
        backbone_factory = keras.applications.MobileNetV2
    else:
        raise ValueError(f"Unknown backbone: {backbone}")

    # weights=None: start uninitialised; the trained weights are loaded later.
    feature_extractor = backbone_factory(
        include_top=False,
        weights=None,
        input_tensor=scaled,
    )
    pooled = keras.layers.GlobalAveragePooling2D(name="gap")(feature_extractor.output)

    if num_classes == 1:
        units, activation = 1, "sigmoid"
    else:
        units, activation = num_classes, "softmax"
    outputs = keras.layers.Dense(units, activation=activation, name="pred")(pooled)

    return keras.Model(inputs, outputs, name=f"{backbone}_classifier")
# -------------------------
# 5) Build + load weights
# -------------------------
# Try a few likely backbones automatically. load_weights() is expected to
# raise when the rebuilt architecture does not match the saved weights, so the
# first backbone that loads without error is taken as the match.
CANDIDATE_BACKBONES = [
    "EfficientNetB0",
    "EfficientNetB1",
    "EfficientNetB2",
    "EfficientNetB3",
    "EfficientNetB4",
    "MobileNetV2",
]
# Change these if needed
INPUT_SHAPE = (224, 224, 3)  # (height, width, channels); update if your CT pipeline differs
NUM_CLASSES = 1  # 1 = binary sigmoid, set >1 for multi-class

loaded = False
chosen_backbone = None
last_error = None
for bb in CANDIDATE_BACKBONES:
    print(f"\n--- Trying backbone: {bb} ---", flush=True)
    try:
        model = build_model(input_shape=INPUT_SHAPE, num_classes=NUM_CLASSES, backbone=bb)
        print("Built model. Layers:", len(model.layers), flush=True)
        # Default (non-skipping) load: any mismatch should raise.
        model.load_weights(weights_path)
    except Exception as e:
        # Broad on purpose: any failure just means "not this backbone".
        last_error = e
        print(f"❌ load_weights failed for {bb}.", flush=True)
        print(traceback.format_exc(), flush=True)
    else:
        print(f"✅ Weights loaded successfully with {bb}!", flush=True)
        loaded = True
        chosen_backbone = bb
        break

if not loaded:
    print("\n❗ Could not match weights with any candidate backbone.", flush=True)
    print("Last error type:", type(last_error).__name__ if last_error else None, flush=True)
    # Chain the last load failure so its message and traceback are not lost.
    raise RuntimeError(
        "Architecture mismatch. Use the printed H5 keys above to identify the real backbone "
        "and update build_model() accordingly (input shape, backbone, head)."
    ) from last_error

print("\n✅ Model ready for inference with backbone:", chosen_backbone, flush=True)
# OPTIONAL smoke test: push one all-zeros batch through the model. A failure is
# reported but not raised; it often indicates an INPUT_SHAPE mismatch.
try:
    _smoke_batch = tf.zeros((1,) + INPUT_SHAPE, dtype=tf.float32)
    _smoke_output = model(_smoke_batch, training=False)
except Exception:
    print("Dummy forward failed (may indicate input_shape mismatch).", flush=True)
    print(traceback.format_exc(), flush=True)
else:
    print("Dummy output shape:", _smoke_output.shape, flush=True)
# --------------------
# Preprocess
# --------------------
# PIL's Image.resize() takes (width, height), while INPUT_SHAPE is
# (height, width, channels), so the first two entries are swapped.
IMG_SIZE = (INPUT_SHAPE[1], INPUT_SHAPE[0])


def preprocess(image: Image.Image) -> np.ndarray:
    """Convert an uploaded PIL image into a model-ready batch of one.

    The image is converted to the channel count in INPUT_SHAPE (RGB for 3,
    grayscale "L" for 1) *before* resizing. Pillow always resizes palette
    ("P") and bilevel ("1") images with nearest-neighbour, so converting first
    gives proper resampling.

    Pixel values are left in [0, 255] because the model from build_model()
    begins with a Rescaling(1/255) layer. Dividing here as well would scale
    the input twice. NOTE(review): if that Rescaling layer is removed from the
    architecture, restore the division by 255 here.

    Args:
        image: Uploaded image in any PIL mode.

    Returns:
        float32 array of shape (1, height, width, channels).
    """
    mode = "L" if INPUT_SHAPE[-1] == 1 else "RGB"
    image = image.convert(mode).resize(IMG_SIZE)
    x = np.asarray(image, dtype=np.float32)
    if x.ndim == 2:
        # Grayscale images come back as (H, W); add the channel axis.
        x = x[..., np.newaxis]
    return np.expand_dims(x, axis=0)
# --------------------
# Predict
# --------------------
def predict(image):
# Gradio can pass None if user clicks without uploading or upload fails
if image is None:
return "Please upload an image first."
x = preprocess(image)
pred = float(model.predict(x, verbose=0)[0][0])
# NOTE: Keeping your original logic:
# pred >= 0.5 -> NORMAL, else ABNORMAL
label = "NORMAL" if pred >= THRESHOLD else "ABNORMAL"
confidence = pred if label == "NORMAL" else (1.0 - pred)
if label == "NORMAL" and confidence >= 0.7:
explanation = "✅ The kidney CT scan appears normal with high confidence."
attention_flag = ""
elif label == "NORMAL":
explanation = "⚠️ The scan appears normal, but the model's confidence is low. Consider radiologist review."
attention_flag = "🚨 FLAGGED FOR RADIOLOGIST REVIEW"
else:
explanation = "⚠️ The kidney CT scan shows signs of abnormality. Immediate radiologist attention is recommended."
attention_flag = "🚨 FLAGGED FOR RADIOLOGIST REVIEW"
return (
f"Prediction: {label}\n"
f"Model output: {pred:.4f}\n"
f"Confidence: {confidence:.2%}\n\n"
f"{explanation}\n"
f"{attention_flag}"
)
# --------------------
# Gradio UI: one PIL image in, one multi-line text report out.
# --------------------
_ct_image_input = gr.Image(type="pil", label="Upload CT image")
_result_textbox = gr.Textbox(label="Result", lines=8)

demo = gr.Interface(
    fn=predict,
    inputs=_ct_image_input,
    outputs=_result_textbox,
    title="Kidney CT Classifier",
    description="Upload a kidney CT image. The model predicts if it's NORMAL or ABNORMAL.",
)
# --------------------
# Launch (Spaces-safe): listen on all interfaces, on $PORT if set, else 7860.
# --------------------
if __name__ == "__main__":
    _server_port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=_server_port)