import torch
import torchvision.transforms as transforms
import torchvision.models as models
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.preprocessing import StandardScaler
import joblib
import os
# Disable GPU for TensorFlow to avoid CUDA conflicts
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Set PyTorch device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load trained ViT model (PyTorch)
vit_model = models.vit_b_16(weights="DEFAULT")  # Fixed deprecated 'pretrained'
vit_model.heads = torch.nn.Linear(in_features=768, out_features=2)  # Binary classification

# Load ViT model weights (if available)
vit_model_path = "vit_bc.pth"
if os.path.exists(vit_model_path):
    vit_model.load_state_dict(torch.load(vit_model_path, map_location=device))
vit_model.to(device)
vit_model.eval()
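# "vit_bc.pth" is assumed to be a fine-tuned state_dict matching the 2-class head above,
# saved on the training side roughly like this (hypothetical snippet, not part of this app):
#   torch.save(vit_model.state_dict(), "vit_bc.pth")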
# Define image transformations for ViT
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
])
# Class labels
class_names = ["Benign", "Malignant"]
# Load trained Neural Network model (TensorFlow/Keras), if available
nn_model_path = "my_NN_BC_model.keras"
nn_model = None
if os.path.exists(nn_model_path):
    try:
        nn_model = tf.keras.models.load_model(nn_model_path)
    except Exception as e:
        print(f"Error loading NN model: {e}")
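# "my_NN_BC_model.keras" is assumed to be a Keras classifier whose output carries one
# score per class in class_names, since np.argmax is applied to its predictions below.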
# Load scaler for feature normalization
scaler_path = "nn_bc_scaler.pkl"
scaler = joblib.load(scaler_path) if os.path.exists(scaler_path) else None
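# "nn_bc_scaler.pkl" is assumed to be a fitted sklearn StandardScaler serialized with joblib.
# A hypothetical training-side sketch (not part of this app):
#   scaler = StandardScaler().fit(X_train)   # X_train: (n_samples, 30) feature matrix
#   joblib.dump(scaler, "nn_bc_scaler.pkl")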
# Feature names (30 numeric inputs expected by the NN model)
feature_names = [
    "Mean Radius", "Mean Texture", "Mean Perimeter", "Mean Area", "Mean Smoothness",
    "Mean Compactness", "Mean Concavity", "Mean Concave Points", "Mean Symmetry", "Mean Fractal Dimension",
    "SE Radius", "SE Texture", "SE Perimeter", "SE Area", "SE Smoothness",
    "SE Compactness", "SE Concavity", "SE Concave Points", "SE Symmetry", "SE Fractal Dimension",
    "Worst Radius", "Worst Texture", "Worst Perimeter", "Worst Area", "Worst Smoothness",
    "Worst Compactness", "Worst Concavity", "Worst Concave Points", "Worst Symmetry", "Worst Fractal Dimension"
]
# Example inputs
benign_example = [9.504,12.44,60.34,273.9,0.1024,0.06492,0.02956,0.02076,0.1815,0.06905,0.2773,0.9768,
                  1.909,15.7,0.009606,0.01432,0.01985,0.01421,0.02027,0.002968,10.23,15.66,65.13,314.9,
                  0.1324,0.1148,0.08867,0.06227,0.245,0.07773]
malignant_example = [11.42,20.38,77.58,386.1,0.1425,0.2839,0.2414,0.1052,0.2597,0.09744,0.4956,1.156,
                     3.445,27.23,0.00911,0.07458,0.05661,0.01867,0.05963,0.009208,14.91,26.5,98.87,567.7,
                     0.2098,0.8663,0.6869,0.2575,0.6638,0.173]
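# Each example lists its 30 values in the same order as feature_names above.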
def classify(model_choice, image=None, *features):
    """Classify using ViT (image) or NN (features)."""
    if model_choice == "ViT":
        if image is None:
            return "⚠️ Please upload an image for ViT classification."
        image = image.convert("RGB")
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = vit_model(input_tensor)
        predicted_class = torch.argmax(output, dim=1).item()
        return class_names[predicted_class]
    elif model_choice == "Neural Network":
        if any(f is None for f in features):
            return "⚠️ Please enter all 30 numerical features."
        input_data = np.array(features).reshape(1, -1)
        input_data_std = scaler.transform(input_data) if scaler else input_data
        prediction = nn_model.predict(input_data_std) if nn_model else [[0, 1]]
        predicted_class = np.argmax(prediction)
        return class_names[predicted_class]
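# Optional sanity check (hypothetical usage, kept commented out so it does not run on import):
#   print(classify("Neural Network", None, *benign_example))   # expected "Benign" when the NN model and scaler are loaded
#   print(classify("ViT", Image.open("images/benign (1)_aug_0.png")))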
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Breast Cancer Classification Model")
    gr.Markdown("Select a model and provide input data to classify breast cancer as **Benign** or **Malignant**.")

    with gr.Row():
        model_selector = gr.Radio(["ViT", "Neural Network"], label="🔬 Choose Model", value="ViT")
        image_input = gr.Image(type="pil", label="📷 Upload Image (for ViT)", visible=True)

    example_image = {"🔵 Benign Example Image": "images/benign (1)_aug_0.png",
                     "🔴 Malignant Example Image": "images/malignant (1)_aug_0.png"}

    with gr.Row(visible=True) as example_image_row:
        example_buttons = []
        for label, path in example_image.items():
            with gr.Column():
                gr.Image(value=path, label=label, interactive=False, height=100)
                btn = gr.Button(f"Use {label.split()[1]}")
                example_buttons.append((btn, path))
    # Feature inputs for the NN model, organized into rows of 3 columns
    feature_inputs = []
    for i in range(0, len(feature_names), 3):
        with gr.Row():
            for name in feature_names[i:i + 3]:
                feature_inputs.append(gr.Number(label=name))
    # Example buttons
    def fill_example(example):
        """Pre-fills example inputs."""
        return {feature_inputs[i]: example[i] for i in range(len(feature_inputs))}

    with gr.Row():
        example_btn_1 = gr.Button("🔴 Malignant Example")
        example_btn_2 = gr.Button("🔵 Benign Example")

    output_text = gr.Textbox(label="🔍 Model Prediction", interactive=False)
    def extract_features_from_file(file):
        """Reads an uploaded text file and extracts 30 numerical features."""
        if file is None:
            return [None] * len(feature_inputs)
        try:
            # gr.File(type="binary") passes the raw file contents as bytes
            content = file.decode("utf-8").strip()
            values = [float(x) for x in content.replace(",", " ").split()]
            # Check that we have exactly 30 features, one per feature input
            if len(values) != 30:
                raise ValueError("The file must contain exactly 30 numerical values.")
            return values
        except Exception as e:
            raise gr.Error(f"⚠️ Error processing file: {e}")
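    # A feature file is assumed to be plain text containing 30 numbers separated by commas
    # and/or whitespace, e.g. (illustrative line, not a bundled file):
    #   9.504, 12.44, 60.34, 273.9, 0.1024, ...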
    # Add file upload component (shown only for the NN model)
    file_input = gr.File(label="📂 Upload Feature File (for NN)", type="binary", visible=False)

    # Toggle visibility of inputs based on model selection
    def toggle_inputs(choice):
        image_visibility = choice == "ViT"
        feature_visibility = choice == "Neural Network"
        file_visibility = choice == "Neural Network"
        return ([gr.update(visible=image_visibility)]
                + [gr.update(visible=feature_visibility)] * len(feature_inputs)
                + [gr.update(visible=file_visibility)])

    model_selector.change(toggle_inputs, model_selector, [image_input, *feature_inputs, file_input])

    # Process uploaded file and populate feature fields
    file_input.change(extract_features_from_file, inputs=file_input, outputs=feature_inputs)
    # Bind image preview buttons to update image_input
    for btn, img_path in example_buttons:
        btn.click(lambda p=img_path: Image.open(p), outputs=image_input)

    # Fill the feature fields with the example vector matching each button's label
    example_btn_1.click(lambda: fill_example(malignant_example), None, feature_inputs)
    example_btn_2.click(lambda: fill_example(benign_example), None, feature_inputs)
    classify_button = gr.Button("🔍 Classify")
    classify_button.click(classify, [model_selector, image_input] + feature_inputs, output_text)

demo.launch()
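# To run locally (assumed setup; file and package names are not stated in the app itself):
#   pip install torch torchvision tensorflow gradio scikit-learn joblib pillow numpy
#   python app.py
# demo.launch(share=True) can be used instead to create a temporary public Gradio link.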