# Author: iruda21cse
# Change: switched the input type from an uploaded image to an image URL.
# Commit: c60db81
# import torch
# import torchvision.transforms as transforms
# import gradio as gr
# from model import load_model
# CLASS_NAMES = ['Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax', 'Edema', 'Emphysema', 'Fibrosis', 'Effusion', 'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule', 'Mass', 'Hernia']
# MODEL_PATH = "chexnet_epoch_17_auc_0.8457.pth"
# # Load model
# model = load_model(MODEL_PATH)
# # Define the image transformation pipeline
# def transform_image(image):
# transformation_pipeline = transforms.Compose([
# transforms.Resize(256),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
# ])
# return transformation_pipeline(image).unsqueeze(0)
# # Define the prediction function
# def predict(image):
# pred = []
# img_tensor = transform_image(image)
# with torch.no_grad():
# output = model(img_tensor)
# values = output.squeeze().tolist()
# prediction = torch.sigmoid(output).squeeze().tolist()
# for i in range(len(CLASS_NAMES)):
# pred.append({"disease": CLASS_NAMES[i], "model_value": values[i], "sigmoid_value": prediction[i]})
# return pred
# # Create Gradio interface
# demo = gr.Interface(
# fn=predict,
# inputs=gr.Image(type='pil', label="Upload Image"),
# outputs=gr.JSON(),
# api_name="predict" # Add this line
# )
# demo.launch(share=True, show_error=True)
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
import httpx
from io import BytesIO
from model import load_model
# Output labels, in model-output order: predict() pairs CLASS_NAMES[i]
# with the i-th value of the model's output.
CLASS_NAMES = [
    'Atelectasis', 'Consolidation', 'Infiltration', 'Pneumothorax',
    'Edema', 'Emphysema', 'Fibrosis', 'Effusion',
    'Pneumonia', 'Pleural_Thickening', 'Cardiomegaly', 'Nodule',
    'Mass', 'Hernia',
]

# Checkpoint file handed to load_model().
MODEL_PATH = "chexnet_epoch_17_auc_0.8457.pth"

# Load the model once, at import time, so every request reuses it.
model = load_model(MODEL_PATH)
# Define the image transformation pipeline
def transform_image(image):
transformation_pipeline = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
return transformation_pipeline(image).unsqueeze(0)
# Define the prediction function
def predict(image_url):
    """Fetch an image from a URL and return per-class scores from the model.

    Args:
        image_url: URL of the image to classify (the Gradio textbox value).
            Leading/trailing whitespace is ignored; ``None`` is treated as
            an empty URL.

    Returns:
        On success, a list with one dict per entry of ``CLASS_NAMES``::

            {"disease": <class name>,
             "model_value": <model's raw output for that class>,
             "sigmoid_value": <sigmoid of that raw output>}

        If the image cannot be fetched or decoded, an error message string.
    """
    # NOTE(review): this fetches an arbitrary user-supplied URL from the
    # server (and the app is launched with share=True), which permits SSRF
    # against internal hosts. Consider restricting schemes/hosts.
    url = (image_url or "").strip()
    try:
        # httpx does not follow redirects by default, and raise_for_status()
        # treats a 3xx response as an error; many image hosts redirect.
        resp = httpx.get(url, follow_redirects=True)
        resp.raise_for_status()
    except (httpx.HTTPError, httpx.InvalidURL) as e:
        # InvalidURL is not an HTTPError subclass, so it is listed explicitly.
        return f"Failed to fetch image from URL: {str(e)}"

    try:
        image = Image.open(BytesIO(resp.content)).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError (non-image content) subclasses OSError;
        # truncated/corrupt data also surfaces as OSError during convert().
        return f"Failed to decode image from URL: {str(e)}"

    img_tensor = transform_image(image)
    with torch.no_grad():
        output = model(img_tensor)
    values = output.squeeze().tolist()
    prediction = torch.sigmoid(output).squeeze().tolist()
    # Index by position (not zip) so a mismatch between the model's output
    # size and CLASS_NAMES raises instead of being silently truncated.
    return [
        {"disease": name, "model_value": values[i], "sigmoid_value": prediction[i]}
        for i, name in enumerate(CLASS_NAMES)
    ]
# Create Gradio interface: one URL textbox in, the per-class score list
# (or an error message string) out as JSON. api_name names the endpoint
# "predict".
url_textbox = gr.Textbox(label="Image URL")
json_output = gr.JSON()
demo = gr.Interface(
    predict,
    inputs=url_textbox,
    outputs=json_output,
    api_name="predict",
)
demo.launch(show_error=True, share=True)