Spaces:
Build error
Build error
import sys
import types

# Python 3.13 removed ``audioop`` from the standard library (PEP 594).
# Presumably one of gradio's audio dependencies still imports ``audioop`` (or a
# ``pyaudioop`` fallback) at import time, so ``import gradio`` would fail there.
# Register empty placeholder modules ONLY when the real module is unavailable.
# Stubbing unconditionally would shadow a working ``audioop`` (Python < 3.13, or
# a backport package) with an empty module that lacks all of its functions.
try:
    import audioop  # noqa: F401  (probe only; imported for its side effect)
except ImportError:
    sys.modules["audioop"] = types.ModuleType("audioop")
    sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| from datasets import load_dataset | |
| import model_utils | |
| import visualization_utils | |
# --- Configuration ---
# Base checkpoint name handed to model_utils.load_model_and_processor.
MODEL_NAME_OR_PATH: str = "google/vit-base-patch16-224-in21k"
# Hugging Face dataset queried at startup to recover the label names.
DATASET_PATH: str = "pawlo2013/chest_xray"
# Directory handed to the model loader alongside the checkpoint name
# (presumably where fine-tuned weights live — confirm in model_utils).
MODEL_DIR: str = "./models"
# Folder scanned for sample images offered as clickable UI examples.
EXAMPLES_FOLDER: str = "./examples"
# --- Load Data & Model ---
print("Loading dataset information...")
try:
    # The dataset is fetched only so the label names match the ones used in
    # training. NOTE(review): a non-streaming load_dataset() presumably
    # downloads the whole "train" split just to read these names — consider a
    # metadata-only lookup if startup time matters.
    train_dataset = load_dataset(DATASET_PATH, split="train")
    class_names = train_dataset.features["label"].names
    print(f"Class names loaded: {class_names}")
except Exception as dataset_error:
    # Best-effort: any failure (network, missing "label" column, ...) falls
    # back to a hard-coded label list instead of aborting startup.
    print(
        "Warning: Could not load dataset, using default class names. "
        f"Error: {dataset_error}"
    )
    # NOTE(review): this two-class fallback disagrees with the UI description,
    # which advertises Normal / Viral / Bacterial classes. Confirm it matches
    # the label set of the checkpoint the model loader uses.
    class_names = ["NORMAL", "PNEUMONIA"]
print("Loading model and processor...")
try:
    model, processor = model_utils.load_model_and_processor(
        MODEL_DIR, MODEL_NAME_OR_PATH, class_names
    )
except Exception as e:
    # The app cannot serve predictions without a model. Report the failure,
    # then propagate it so startup aborts instead of continuing half-configured.
    print(f"Error loading model: {e}")
    # Bare ``raise`` re-raises the active exception with its traceback intact.
    raise
else:
    print("Model and processor loaded successfully!")
| # --- Core Logic --- | |
def classify_and_visualize(img, device="cpu"):
    """Classify a chest X-ray and build an attention heatmap for it.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing was uploaded.
        device: Device string forwarded to ``model_utils.predict`` and the
            heatmap helper. Defaults to ``"cpu"``.

    Returns:
        A ``(result, heatmap_img)`` tuple, depending on the outcome:

        - No image: ``(None, None)``.
        - Success: ``result`` maps each class name to its probability as a
          plain ``float``.
        - Prediction fails: every class maps to ``0.0`` and the heatmap is
          ``None``.
        - Only the heatmap fails: the real probabilities are still returned,
          with a ``None`` heatmap. A visualization error no longer discards
          a valid prediction.
    """
    if img is None:
        return None, None

    try:
        outputs, processed_input, probabilities, _prediction_idx = model_utils.predict(
            model, processor, img, device
        )
        # float() normalizes tensor / NumPy scalars into plain floats for
        # gr.Label. Assumes ``probabilities`` iterates one value per class, in
        # class_names order — TODO confirm against model_utils.predict.
        result = {
            class_name: float(prob)
            for class_name, prob in zip(class_names, probabilities)
        }
    except Exception as e:
        print(f"Error in classification: {e}")
        # Return empty result and None for heatmap on error
        empty_result = {class_name: 0.0 for class_name in class_names}
        return empty_result, None

    # The heatmap is auxiliary: if it fails, keep the prediction and omit it.
    try:
        heatmap_img = visualization_utils.show_final_layer_attention_maps(
            outputs, processed_input, device
        )
    except Exception as e:
        print(f"Error generating heatmap: {e}")
        heatmap_img = None
    return result, heatmap_img
def format_output(img):
    """Gradio callback: return ``(class probabilities, heatmap)`` for *img*.

    Delegates to ``classify_and_visualize``. As a last line of defence, any
    exception raised there (or a result that does not unpack into exactly two
    values) is printed and replaced by an all-zero probability map with no
    heatmap.
    """
    try:
        label_scores, attention_map = classify_and_visualize(img)
    except Exception as failure:
        print(f"Error in format_output: {failure}")
        zeroed_scores = {name: 0.0 for name in class_names}
        return zeroed_scores, None
    return label_scores, attention_map
| # --- Helper Functions --- | |
def load_examples_from_folder(folder_path, extensions=(".png", ".jpg", ".jpeg")):
    """Return paths of the example images found directly inside *folder_path*.

    Args:
        folder_path: Directory to scan (not recursive).
        extensions: File suffixes to accept, matched case-insensitively.
            Defaults to PNG and JPEG.

    Returns:
        A list of ``os.path.join(folder_path, name)`` paths, sorted by file
        name so the UI shows examples in a stable order. The list is empty
        when *folder_path* is missing or is not a directory.
    """
    # isdir (not exists) so that a file path yields [] instead of crashing
    # inside os.listdir with NotADirectoryError.
    if not os.path.isdir(folder_path):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    examples = []
    # os.listdir order is arbitrary; sort for deterministic example ordering.
    for name in sorted(os.listdir(folder_path)):
        path = os.path.join(folder_path, name)
        # Skip sub-directories that merely happen to be named like images.
        if name.lower().endswith(suffixes) and os.path.isfile(path):
            examples.append(path)
    return examples
# Sample images offered as clickable examples in the interface.
examples = load_examples_from_folder(folder_path=EXAMPLES_FOLDER)
# --- UI Layout ---
# Static text shown by gr.Interface: a heading, an HTML intro placed above the
# inputs, and an HTML "article" below the outputs that carries the medical
# disclaimer.
title = 'Pneumonia Detection Assistant'

# Centered HTML blurb explaining what to upload and what the outputs mean.
description = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<p>Upload a Chest X-Ray image to analyze it for signs of Pneumonia.</p>
<p>The model classifies the image into <b>Normal</b>, <b>Viral Pneumonia</b>, or <b>Bacterial Pneumonia</b> categories
and provides an attention heatmap to show which areas influenced the decision.</p>
</div>
"""

# Red-bordered disclaimer box: this app is not a diagnostic tool.
article = """
<div style="border: 2px solid #e74c3c; padding: 20px; border-radius: 10px; margin-top: 20px; background-color: #fce4e4;">
<h3 style="color: #c0392b; margin-top: 0;">⚠️ MEDICAL DISCLAIMER</h3>
<p style="color: #7f8c8d;">
This application uses Artificial Intelligence (Vision Transformer) for educational and research purposes only.
<b>It is NOT a diagnostic tool.</b> The results generated by this model should not be treated as medical advice.
Always consult with a qualified healthcare professional for medical diagnosis and treatment.
</p>
</div>
"""
# Soft theme in blue/slate. Primary buttons use shade 500 at rest and the
# darker shade 600 on hover.
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")
theme = theme.set(
    button_primary_background_fill="*primary_500",
    button_primary_background_fill_hover="*primary_600",
)

# Components are created in the same order as before: input image, then the
# confidence label, then the heatmap image.
xray_input = gr.Image(type="pil", label="Upload Chest X-Ray")
confidence_output = gr.Label(label="Prediction Confidence", num_top_classes=3)
heatmap_output = gr.Image(label="Attention Heatmap Analysis")

iface = gr.Interface(
    fn=format_output,
    inputs=xray_input,
    outputs=[confidence_output, heatmap_output],
    examples=examples,
    cache_examples=False,
    title=title,
    description=description,
    article=article,
    theme=theme,
    # NOTE(review): flagging is left at Gradio's default. The option below was
    # commented out; newer Gradio releases call it `flagging_mode`. Confirm
    # whether users should be able to flag (and thus persist) uploaded X-rays.
    # allow_flagging="never"
)
# Launch
if __name__ == "__main__":
    # Listen on every interface at port 7860.
    # NOTE(review): share=True asks Gradio for a public tunnel link when run
    # outside a hosted environment. Confirm that exposing uploaded medical
    # images publicly is intended.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    iface.launch(**launch_options)