# Spaces: Sleeping  (hosting-page status text captured with the source; not part of the program)
import logging
import os
import zipfile

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
# Side length, in pixels, that every image is resized to before inference.
IMG_SIZE = 256

# On-disk locations of the model: the extraction directory, the archive it is
# unpacked from, and the graph file whose presence marks a usable extraction.
_MODEL_DIR = "tuberculosis_cnn_saved_model"
_MODEL_ARCHIVE = "tuberculosis_cnn_saved_model.zip"
_MODEL_GRAPH = "tuberculosis_cnn_saved_model/saved_model.pb"

# Extract SavedModel
# Guard on the graph file rather than on the directory: the directory is
# created before extraction starts, so an interrupted or partial extraction
# would otherwise be mistaken for a finished one and never retried.
if not os.path.exists(_MODEL_GRAPH):
    print("Extracting SavedModel...")
    os.makedirs(_MODEL_DIR, exist_ok=True)
    with zipfile.ZipFile(_MODEL_ARCHIVE, 'r') as zip_ref:
        zip_ref.extractall(_MODEL_DIR)
    # The archive must place saved_model.pb directly inside the directory.
    if not os.path.exists(_MODEL_GRAPH):
        raise FileNotFoundError("saved_model.pb not found")
    print("SavedModel extracted")

print("Loading model...")
model = tf.saved_model.load(_MODEL_DIR)
# Callable used by predict() to run inference.
infer = model.signatures['serving_default']
# Indexed by the class id computed in predict(): 0 -> Normal, 1 -> Tuberculosis.
class_names = ['Normal', 'Tuberculosis']
print("Model loaded")
def predict(image):
    """Classify a chest X-ray image as Normal or Tuberculosis.

    Args:
        image: The uploaded picture (a PIL image from the Gradio input), or
            None when nothing was uploaded. Anything ``np.array`` can turn
            into a 2-D grayscale or 3-D RGB pixel array is accepted.

    Returns:
        A two-line string naming the predicted class and its confidence, or
        a string starting with "Error: " when the input is missing or any
        processing step fails. Exceptions are never propagated to the caller.
    """
    if image is None:
        # Nothing was uploaded; answer clearly instead of letting the image
        # conversion below fail with an unrelated-looking error.
        return "Error: No image provided. Please upload a chest X-ray image."

    try:
        pixels = np.array(image)
        # cv2.COLOR_RGB2GRAY rejects single-channel input, so an image that
        # is already 2-D grayscale is passed through unchanged.
        if pixels.ndim == 2:
            gray = pixels
        else:
            gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)

        # Resize, scale to [0, 1], and shape as a single-image batch with
        # one channel.
        img = cv2.resize(gray, (IMG_SIZE, IMG_SIZE)) / 255.0
        img = img.reshape(1, IMG_SIZE, IMG_SIZE, 1).astype(np.float32)
        # Catches inputs whose pixel values are not in the 0-255 range.
        if img.max() > 1.0 or img.min() < 0.0:
            raise ValueError("Image not normalized")

        # The signature's keyword name is looked up rather than hard-coded;
        # the first (assumed only) input and output entries are used.
        input_key = next(iter(infer.structured_input_signature[1]))
        outputs = infer(**{input_key: tf.convert_to_tensor(img)})
        pred = next(iter(outputs.values())).numpy()

        # pred[0][0] is treated as the Tuberculosis score; 0.5 is the
        # decision threshold, and confidence is reported for the chosen class.
        tb_score = pred[0][0]
        class_id = 1 if tb_score > 0.5 else 0
        confidence = tb_score if class_id == 1 else 1 - tb_score
        return f"Classified as: {class_names[class_id]}\nConfidence: {confidence:.4f}"
    except Exception as e:
        # UI boundary: report the failure as text, but keep the traceback in
        # the server log so the cause is not lost.
        logging.exception("Prediction failed")
        return f"Error: {str(e)}"
# Create Gradio interface
# The input and output components are built first, then wired to predict().
_image_input = gr.Image(type="pil", label="Upload Chest X-ray Image")
_result_output = gr.Textbox(label="Classification Result")

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_result_output,
    title="🩺 Tuberculosis Classification CNN",
    description="Classify chest X-ray images as Normal or Tuberculosis (~93% accuracy).",
    examples=None,
    cache_examples=False,
)
if __name__ == "__main__":
    # Start the web server only when this file is run as a script.
    # Listens on all interfaces at port 7860, with no public share link
    # and with error display enabled.
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    iface.launch(**_launch_options)