Spaces:
Runtime error
Runtime error
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from torchvision import transforms | |
| from TumorModel import TumorClassification, GliomaStageModel | |
| # βββ Load Tumor Classification Model βββββββββββββββββββββββββββββββββββββββββββ | |
| tumor_model = TumorClassification() | |
| sd = torch.load("BTD_model.pth", map_location="cpu") | |
| renamed_sd = {} | |
| for k, v in sd.items(): | |
| new_key = (k | |
| .replace("con1d.", "model.0.") | |
| .replace("con2d.", "model.3.") | |
| .replace("con3d.", "model.6.") | |
| .replace("fc1.", "model.8.") | |
| .replace("fc2.", "model.10.") | |
| .replace("output.", "model.12.")) | |
| renamed_sd[new_key] = v | |
| tumor_model.load_state_dict(renamed_sd) | |
| tumor_model.eval() | |
| # βββ Load Glioma Stage Model βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| glioma_model = GliomaStageModel() | |
| glioma_model.load_state_dict(torch.load("glioma_stages.pth", map_location="cpu")) | |
| glioma_model.eval() | |
| # βββ Labels and Image Transform βββββββββββββββββββββββββββββββββββββββββββββββ | |
| tumor_labels = ['glioma', 'meningioma', 'notumor', 'pituitary'] | |
| stage_labels = ['Stage 1', 'Stage 2'] # Or adjust to match your second model | |
| transform = transforms.Compose([ | |
| transforms.Grayscale(), | |
| transforms.Resize((208, 208)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(mean=[0.5], std=[0.5]) | |
| ]) | |
| # βββ Gradio Prediction Functions βββββββββββββββββββββββββββββββββββββββββββββββ | |
| def predict_tumor(image): | |
| tensor = transform(image).unsqueeze(0) | |
| with torch.no_grad(): | |
| out = tumor_model(tensor) | |
| pred = torch.argmax(out, dim=1).item() | |
| return tumor_labels[pred] | |
| def predict_stage(gender, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca): | |
| gender_val = 0 if gender == "Male" else 1 | |
| features = [gender_val, age, idh1, tp53, atrx, pten, egfr, cic, pik3ca] | |
| x = torch.tensor(features).float().unsqueeze(0) | |
| with torch.no_grad(): | |
| out = glioma_model(x) | |
| pred = torch.argmax(out, dim=1).item() | |
| return stage_labels[pred] | |
| # βββ Gradio UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| tumor_tab = gr.Interface( | |
| fn=predict_tumor, | |
| inputs=gr.Image(type="pil"), | |
| outputs=gr.Label(), | |
| title="Brain Tumor Detector" | |
| ) | |
| stage_tab = gr.Interface( | |
| fn=predict_stage, | |
| inputs=[ | |
| gr.Radio(["Male", "Female"], label="Gender"), | |
| gr.Slider(0, 100, label="Age"), | |
| gr.Slider(0, 1, step=1, label="IDH1"), | |
| gr.Slider(0, 1, step=1, label="TP53"), | |
| gr.Slider(0, 1, step=1, label="ATRX"), | |
| gr.Slider(0, 1, step=1, label="PTEN"), | |
| gr.Slider(0, 1, step=1, label="EGFR"), | |
| gr.Slider(0, 1, step=1, label="CIC"), | |
| gr.Slider(0, 1, step=1, label="PIK3CA") | |
| ], | |
| outputs=gr.Label(), | |
| title="Glioma Stage Predictor" | |
| ) | |
| demo = gr.TabbedInterface([tumor_tab, stage_tab], tab_names=["Tumor Detector", "Glioma Stage"]) | |
| demo.launch() | |