dewifaj commited on
Commit
8352b63
·
verified ·
1 Parent(s): b676b53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -34
app.py CHANGED
@@ -1,35 +1,23 @@
1
- import streamlit as st
2
- from PIL import Image
3
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
4
- import torch
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Load the model and feature extractor
7
- model_name = "dewifaj/resnet18_alzheimer_classifier"
8
- model = AutoModelForImageClassification.from_pretrained(model_name)
9
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
10
-
11
- # Define the label mapping
12
- label_mapping = model.config.id2label
13
-
14
- def predict(image):
15
- inputs = feature_extractor(images=image, return_tensors="pt")
16
- with torch.no_grad():
17
- outputs = model(**inputs)
18
- logits = outputs.logits
19
- predicted_class_idx = logits.argmax(-1).item()
20
- return label_mapping[predicted_class_idx]
21
-
22
- # Streamlit app
23
- st.title("Alzheimer Image Classification")
24
- st.write("Upload an image to classify the stage of Alzheimer's disease.")
25
-
26
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
27
-
28
- if uploaded_file is not None:
29
- image = Image.open(uploaded_file)
30
- st.image(image, caption='Uploaded Image', use_column_width=True)
31
- st.write("")
32
- st.write("Classifying...")
33
-
34
- label = predict(image)
35
- st.write(f"The model predicts: **{label}**")
 
1
+ from transformers import pipeline
2
+ import gradio as gr
3
+ def alz_mri_classification(image):
4
+ classifier = pipeline("image-classification", model="dewifaj/alzheimer_classification")
5
+ result = classifier(image)
6
+ # extract the highest score
7
+ prediction = result[0]
8
+ score = prediction['score']
9
+ label = prediction['label']
10
+ return {"score": score, "label": label}
11
+
12
+ example_image_paths = ["Very_Mild_Demented.png",
13
+ "Mild_Demented.png",
14
+ "Moderate_Demented.png",
15
+ "Non_Demented.png"]
16
 
17
+ image_input = gr.Image(type="pil", label="Upload Image")
18
+ iface = gr.Interface(fn=alz_mri_classification,
19
+ inputs=image_input,
20
+ outputs="json",
21
+ examples=example_image_paths,
22
+ title="Alzheimer Recognition from MRI")
23
+ iface.launch()