Spaces:
Build error
Build error
| # import torch | |
| # import torchvision.transforms as transforms | |
| # import gradio as gr | |
| # from model import load_model | |
| # CLASS_NAMES = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia'] | |
| # MODEL_PATH = "chexnet_epoch_17_auc_0.8457.pth" | |
| # # Load model | |
| # model = load_model(MODEL_PATH) | |
| # # Define the image transformation pipeline | |
| # def transform_image(image): | |
| # transformation_pipeline = transforms.Compose([ | |
| # transforms.Resize(256), | |
| # transforms.ToTensor(), | |
| # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | |
| # ]) | |
| # return transformation_pipeline(image).unsqueeze(0) | |
| # # Define the prediction function | |
| # def predict(image): | |
| # pred = [] | |
| # img_tensor = transform_image(image) | |
| # with torch.no_grad(): | |
| # output = model(img_tensor) | |
| # values = output.squeeze().tolist() | |
| # prediction = torch.sigmoid(output).squeeze().tolist() | |
| # for i in range(len(CLASS_NAMES)): | |
| # pred.append({"disease": CLASS_NAMES[i], "model_value": values[i], "sigmoid_value": prediction[i]}) | |
| # return pred | |
| # # Create Gradio interface | |
| # demo = gr.Interface( | |
| # fn=predict, | |
| # inputs=gr.Image(type='pil', label="Upload Image"), | |
| # outputs=gr.JSON(), | |
| # api_name="predict" # Add this line | |
| # ) | |
| # demo.launch(share=True, show_error=True) | |
| import torch | |
| import torchvision.transforms as transforms | |
| import gradio as gr | |
| from PIL import Image | |
| import httpx | |
| from io import BytesIO | |
| from model import load_model | |
# Output labels. `predict` reports element i of the model's output vector
# under CLASS_NAMES[i], so this order must match the model's output order.
CLASS_NAMES = [
    'Atelectasis',
    'Consolidation',
    'Infiltration',
    'Pneumothorax',
    'Edema',
    'Emphysema',
    'Fibrosis',
    'Effusion',
    'Pneumonia',
    'Pleural_Thickening',
    'Cardiomegaly',
    'Nodule',
    'Mass',
    'Hernia',
]
# Model file path handed to `load_model` when the module is imported.
MODEL_PATH = "chexnet_epoch_17_auc_0.8457.pth"
# Load the model once at import time; `predict` reuses this single instance
# for every request.
# NOTE(review): predict never calls model.eval(), so this assumes load_model
# returns the network already in eval mode — confirm in model.py.
model = load_model(MODEL_PATH)
def transform_image(image):
    """Turn a PIL image into a normalized, batched tensor for the model.

    The image is resized so its shorter side is 256 pixels (aspect ratio
    preserved), converted to a float tensor, and normalized per channel with
    the standard ImageNet mean/std values. A leading batch dimension of size 1
    is then added.

    Args:
        image: A PIL image (``predict`` passes one converted to RGB).

    Returns:
        A tensor of shape ``(1, C, H, W)``.
    """
    steps = [
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    pipeline = transforms.Compose(steps)
    single = pipeline(image)
    # Prepend the batch axis expected by the model call in `predict`.
    return single.unsqueeze(0)
def predict(image_url):
    """Download the image at *image_url* and score it against every class.

    Args:
        image_url: URL of the image to classify, as typed into the Gradio
            textbox.

    Returns:
        On success, a list with one dict per entry of ``CLASS_NAMES``. Each
        dict holds ``"disease"`` (the class name), ``"model_value"`` (the raw
        model output for that class) and ``"sigmoid_value"`` (that output
        passed through a sigmoid). If the image cannot be fetched or decoded,
        a human-readable error string is returned instead.

    NOTE(review): the URL is user-supplied and fetched server-side, so this
    endpoint can reach any host the server can (SSRF). Restrict schemes and
    hosts if the app is exposed publicly.
    """
    try:
        # httpx does not follow redirects by default, and raise_for_status()
        # treats a 3xx response as an error. Many image hosts redirect, so
        # follow them explicitly.
        resp = httpx.get(image_url, follow_redirects=True)
        resp.raise_for_status()
    except httpx.HTTPError as e:
        return f"Failed to fetch image from URL: {e}"

    try:
        # Image.open is lazy; convert() forces the decode, so corrupt or
        # non-image content is detected inside this try.
        image = Image.open(BytesIO(resp.content)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL.UnidentifiedImageError (not an image) and
        # truncated data. DecompressionBombError is raised for absurdly large
        # images, which an untrusted URL can easily serve.
        return f"Failed to decode image from URL: {e}"

    img_tensor = transform_image(image)
    with torch.no_grad():
        output = model(img_tensor)
    values = output.squeeze().tolist()
    prediction = torch.sigmoid(output).squeeze().tolist()
    # Index by class position so a model emitting fewer outputs than
    # CLASS_NAMES still fails loudly (IndexError) instead of truncating.
    return [
        {"disease": name, "model_value": values[i], "sigmoid_value": prediction[i]}
        for i, name in enumerate(CLASS_NAMES)
    ]
# Gradio interface: one URL textbox in, and the result of `predict` (a list of
# per-class dicts or an error string) rendered as JSON out.
url_input = gr.Textbox(label="Image URL")
json_output = gr.JSON()
demo = gr.Interface(
    fn=predict,
    inputs=url_input,
    outputs=json_output,
    # Name under which API clients call this endpoint.
    api_name="predict",
)
# show_error=True surfaces exceptions raised by `predict` in the UI.
# NOTE(review): share=True requests a public share link when run locally;
# Hugging Face Spaces ignore it. Because `predict` fetches arbitrary URLs,
# confirm that public exposure is intended.
demo.launch(share=True, show_error=True)