Spaces:
Runtime error
Runtime error
| import os | |
| os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" | |
| import tensorflow as tf | |
| import keras | |
| from huggingface_hub import hf_hub_download | |
| import zipfile | |
| import h5py | |
| import traceback | |
| import shutil | |
print("TF:", tf.__version__, flush=True)
print("Keras:", keras.__version__, flush=True)

REPO_ID = "TaliZG03/kidney_normal_CT_classifier_model"
MODEL_FILENAME = "model.keras"

# -------------------------
# 1) Download the .keras archive from the Hub. It cannot be loaded as-is,
#    so only the weights file stored inside it is used.
# -------------------------
model_zip = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
print("Downloaded model.keras:", model_zip, flush=True)

# -------------------------
# 2) Pull model.weights.h5 out of the .keras archive (opened here as a zip file)
# -------------------------
WEIGHTS_MEMBER = "model.weights.h5"
extract_dir = "/tmp/extracted"
# Start from an empty directory so nothing left over from a previous run is reused.
shutil.rmtree(extract_dir, ignore_errors=True)
os.makedirs(extract_dir, exist_ok=True)
weights_path = os.path.join(extract_dir, WEIGHTS_MEMBER)

with zipfile.ZipFile(model_zip, "r") as archive:
    print("Archive contents:", archive.namelist(), flush=True)
    archive.extract(WEIGHTS_MEMBER, extract_dir)
print("Extracted weights:", weights_path, flush=True)
# -------------------------
# 3) Inspect the weights file to understand the architecture
#    (prints the top-level H5 keys and the paths of datasets found in the file)
# -------------------------
def inspect_h5(h5_path: str, max_root=120, max_datasets=60):
    """Print a summary of an HDF5 file's layout.

    Prints how many top-level keys the file has and the first ``max_root`` of
    them, then how many datasets exist anywhere in the file and the first
    ``max_datasets`` dataset paths (in ``visititems`` traversal order).

    Args:
        h5_path: Path of the HDF5 file to open read-only.
        max_root: Maximum number of top-level keys to print.
        max_datasets: Maximum number of dataset paths to print.
    """
    print("\n=== H5 INSPECTION ===", flush=True)
    with h5py.File(h5_path, "r") as h5:
        top_level = [key for key in h5.keys()]
        print("H5 root keys count:", len(top_level), flush=True)
        print("H5 root keys (first):", top_level[:max_root], flush=True)

        dataset_paths = []

        def _collect(path, node):
            # Must return None: visititems stops walking as soon as the
            # callback returns any other value.
            if isinstance(node, h5py.Dataset):
                dataset_paths.append(path)

        h5.visititems(_collect)

    print("\nDataset count:", len(dataset_paths), flush=True)
    print("Dataset names (first):", dataset_paths[:max_datasets], flush=True)
    print("=== END H5 INSPECTION ===\n", flush=True)


inspect_h5(weights_path)
# -------------------------
# 4) Rebuild the architecture (PLACEHOLDER)
#    IMPORTANT:
#    - It MUST match the original training architecture exactly.
#    - This is a best-guess template.
#    - A Rescaling layer is used instead of Normalization, so no saved
#      mean/var/count values are needed.
# -------------------------
def build_model(input_shape=(224, 224, 3), num_classes=1, backbone="EfficientNetB3"):
    """Assemble rescaling -> backbone -> global average pooling -> dense head.

    Try common backbones by changing ``backbone``; also adjust ``input_shape``
    and ``num_classes`` to match how the model was trained.

    Args:
        input_shape: Shape of one input image without the batch dimension.
            CT inputs might be (512, 512, 1) or (224, 224, 3).
        num_classes: 1 builds a single sigmoid unit (binary); >1 builds a
            softmax layer with that many units.
        backbone: Any ``keras.applications`` attribute whose name starts with
            "EfficientNet" (e.g. "EfficientNetB0" ... "EfficientNetB3"), or
            "MobileNetV2". The backbone is created without pretrained weights.

    Returns:
        An uncompiled ``keras.Model`` named ``f"{backbone}_classifier"``.

    Raises:
        ValueError: ``backbone`` is neither an EfficientNet name nor "MobileNetV2".
        AttributeError: an "EfficientNet..." name that ``keras.applications``
            does not provide.
    """
    inputs = keras.Input(shape=input_shape, name="input")
    # Weight-free preprocessing layer: nothing for it has to exist in the weights file.
    # NOTE(review): Keras documents that its EfficientNet applications already
    # rescale [0, 255] inputs internally; confirm this extra layer matches training.
    scaled = keras.layers.Rescaling(1.0 / 255.0, name="rescaling")(inputs)

    if backbone.startswith("EfficientNet"):
        backbone_factory = getattr(keras.applications, backbone)
    elif backbone == "MobileNetV2":
        backbone_factory = keras.applications.MobileNetV2
    else:
        raise ValueError(f"Unknown backbone: {backbone}")

    # weights=None: the trained weights are loaded afterwards via load_weights().
    base = backbone_factory(include_top=False, weights=None, input_tensor=scaled)
    pooled = keras.layers.GlobalAveragePooling2D(name="gap")(base.output)

    if num_classes == 1:
        units, activation = 1, "sigmoid"
    else:
        units, activation = num_classes, "softmax"
    outputs = keras.layers.Dense(units, activation=activation, name="pred")(pooled)

    return keras.Model(inputs, outputs, name=f"{backbone}_classifier")
# -------------------------
# 5) Build + load weights
# -------------------------
# Try a few likely backbones in turn; the first one whose load_weights()
# call does not raise is kept.
CANDIDATE_BACKBONES = [
    "EfficientNetB0",
    "EfficientNetB1",
    "EfficientNetB2",
    "EfficientNetB3",
    "EfficientNetB4",
    "MobileNetV2",
]
# Change these if needed
INPUT_SHAPE = (224, 224, 3)  # update if your CT pipeline differs
NUM_CLASSES = 1  # 1 = binary sigmoid, set >1 for multi-class

loaded = False
last_error = None  # most recent failure; becomes the cause if nothing fits
for bb in CANDIDATE_BACKBONES:
    print(f"\n--- Trying backbone: {bb} ---", flush=True)
    try:
        model = build_model(input_shape=INPUT_SHAPE, num_classes=NUM_CLASSES, backbone=bb)
        print("Built model. Layers:", len(model.layers), flush=True)
        # One plain load attempt per candidate (no partial / skip-mismatch
        # fallback); a mismatching architecture is expected to raise here.
        model.load_weights(weights_path)
    except Exception as e:
        # Deliberately broad: any build/load failure just means "try the next one".
        last_error = e
        print(f"❌ load_weights failed for {bb}.", flush=True)
        # Print traceback text (safe)
        print(traceback.format_exc(), flush=True)
        continue
    print(f"✅ Weights loaded successfully with {bb}!", flush=True)
    loaded = True
    chosen_backbone = bb
    break

if not loaded:
    print("\n❗ Could not match weights with any candidate backbone.", flush=True)
    print(
        "Last error type:",
        type(last_error).__name__ if last_error is not None else None,
        flush=True,
    )
    # Chain the last load failure so its traceback is kept behind this summary error.
    raise RuntimeError(
        "Architecture mismatch. Use the printed H5 keys above to identify the real backbone "
        "and update build_model() accordingly (input shape, backbone, head)."
    ) from last_error

print("\n✅ Model ready for inference with backbone:", chosen_backbone, flush=True)

# OPTIONAL: test a dummy forward pass (adjust shape if needed)
try:
    dummy = tf.zeros((1,) + INPUT_SHAPE, dtype=tf.float32)
    y = model(dummy, training=False)
    print("Dummy output shape:", y.shape, flush=True)
except Exception:
    # Best-effort smoke test: report the problem and carry on.
    print("Dummy forward failed (may indicate input_shape mismatch).", flush=True)
    print(traceback.format_exc(), flush=True)
# --------------------
# Preprocess
# --------------------
# The annotations are strings because neither `Image` (PIL) nor `np` is
# imported at module level; bare names here raised NameError at import time.
def preprocess(image: "Image.Image", size=None) -> "np.ndarray":
    """Turn one image into a single-image float32 batch for ``model``.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or any
            object with PIL's ``convert``/``resize`` API that numpy can
            convert to an array.
        size: Target ``(width, height)`` for ``resize``. Defaults to the
            spatial dimensions of ``INPUT_SHAPE`` (height, width, channels).

    Returns:
        Array of shape ``(1, height, width, 3)``, dtype float32, holding the
        RGB pixel values unscaled.
    """
    import numpy as np  # local import: the file never imports numpy at top level

    if size is None:
        # PIL's resize() takes (width, height); INPUT_SHAPE is (height, width, channels).
        size = (INPUT_SHAPE[1], INPUT_SHAPE[0])

    # Convert first: Pillow always resizes mode "1"/"P" images with
    # nearest-neighbour, so converting to RGB before resizing resamples properly.
    image = image.convert("RGB").resize(size)

    # No division by 255 here: build_model() already puts a Rescaling(1/255)
    # layer in front of the backbone, so scaling here too would apply it twice.
    # NOTE(review): confirm against the original training pipeline.
    x = np.asarray(image, dtype=np.float32)
    return np.expand_dims(x, axis=0)
# --------------------
# Predict
# --------------------
def predict(image, threshold: float = 0.5) -> str:
    """Classify one uploaded image and return a human-readable report.

    Args:
        image: The image passed by Gradio (a PIL image, given ``type="pil"``),
            or None when nothing was uploaded.
        threshold: Model outputs at or above this value are labelled NORMAL,
            anything below ABNORMAL. Defaults to 0.5, the original cut-off.

    Returns:
        Multi-line text with the label, raw model output, confidence,
        an explanation and (when applicable) a radiologist-review flag.
    """
    # Gradio can pass None if user clicks without uploading or upload fails
    if image is None:
        return "Please upload an image first."

    x = preprocess(image)
    # Assumes the single-unit sigmoid head (NUM_CLASSES == 1): read element [0][0].
    pred = float(model.predict(x, verbose=0)[0][0])

    # Original logic: output >= threshold -> NORMAL, otherwise ABNORMAL.
    label = "NORMAL" if pred >= threshold else "ABNORMAL"
    # Confidence is the probability mass on whichever label was chosen.
    confidence = pred if label == "NORMAL" else (1.0 - pred)

    if label == "NORMAL" and confidence >= 0.7:
        explanation = "✅ The kidney CT scan appears normal with high confidence."
        attention_flag = ""
    elif label == "NORMAL":
        explanation = "⚠️ The scan appears normal, but the model's confidence is low. Consider radiologist review."
        attention_flag = "🚨 FLAGGED FOR RADIOLOGIST REVIEW"
    else:
        explanation = "⚠️ The kidney CT scan shows signs of abnormality. Immediate radiologist attention is recommended."
        attention_flag = "🚨 FLAGGED FOR RADIOLOGIST REVIEW"

    return (
        f"Prediction: {label}\n"
        f"Model output: {pred:.4f}\n"
        f"Confidence: {confidence:.2%}\n\n"
        f"{explanation}\n"
        f"{attention_flag}"
    )
# --------------------
# Gradio UI
# --------------------
# NOTE(review): `gr` is not imported anywhere in this file; `import gradio as gr`
# has to be in scope before this point or module import fails with NameError.
image_input = gr.Image(type="pil", label="Upload CT image")
result_box = gr.Textbox(label="Result", lines=8)
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=result_box,
    title="Kidney CT Classifier",
    description="Upload a kidney CT image. The model predicts if it's NORMAL or ABNORMAL.",
)

# --------------------
# Launch (Spaces-safe): listen on all interfaces, port from $PORT (default 7860)
# --------------------
if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port)