Sazzz02 commited on
Commit
3771187
·
verified ·
1 Parent(s): 0c7ebff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -20
app.py CHANGED
@@ -1,30 +1,34 @@
 
 
 
1
  import gradio as gr
2
- from transformers import AutoModelForImageClassification, AutoImageProcessor
3
- from PIL import Image
4
- import torch
5
 
6
- # Load the working public model
7
- model_name = "projectswithjam/skin-condition-classifier"
8
- model = AutoModelForImageClassification.from_pretrained(model_name)
9
- processor = AutoImageProcessor.from_pretrained(model_name)
10
- labels = model.config.id2label
 
 
 
 
 
11
 
12
  # Prediction function
13
- def predict(image):
14
- image = image.convert("RGB")
15
- inputs = processor(images=image, return_tensors="pt")
16
- with torch.no_grad():
17
- outputs = model(**inputs)
18
- logits = outputs.logits
19
- predicted_idx = logits.argmax(-1).item()
20
- label = labels[str(predicted_idx)]
21
- return f"Prediction: {label}"
22
 
23
- # Gradio Interface
24
  gr.Interface(
25
  fn=predict,
26
  inputs=gr.Image(type="pil"),
27
  outputs="text",
28
- title="Skin Disease Classifier",
29
- description="Upload a skin or scalp image. This model predicts the skin condition using a transformer trained on DermNet data."
30
  ).launch()
 
 
1
+ import numpy as np
2
+ from keras.models import load_model
3
+ from keras.preprocessing import image
4
  import gradio as gr
 
 
 
5
 
6
+ # Load the model (ensure this .h5 file is in the same folder)
7
+ model = load_model("VGG16-Final.h5")
8
+
9
+ # Class labels (update if different)
10
+ class_names = [
11
+ 'Alopecia Areata', 'Contact Dermatitis', 'Folliculitis',
12
+ 'Head Lice', 'Lichen Planus', 'Male Pattern Baldness',
13
+ 'Psoriasis', 'Seborrheic Dermatitis', 'Telogen Effluvium',
14
+ 'Tinea Capitis'
15
+ ]
16
 
17
  # Prediction function
18
+ def predict(img):
19
+ img = img.resize((224, 224)) # Match input shape of model
20
+ img_array = image.img_to_array(img)
21
+ img_array = np.expand_dims(img_array, axis=0) / 255.0 # Normalize
22
+ prediction = model.predict(img_array)
23
+ predicted_class = class_names[np.argmax(prediction)]
24
+ return f"Prediction: {predicted_class}"
 
 
25
 
26
+ # Gradio interface
27
  gr.Interface(
28
  fn=predict,
29
  inputs=gr.Image(type="pil"),
30
  outputs="text",
31
+ title="Hair/Scalp Disease Classifier",
32
+ description="Upload a scalp image to classify the condition using a VGG16-based CNN model."
33
  ).launch()
34
+