ehsanwebdev99 commited on
Commit
a0dcc35
·
verified ·
1 Parent(s): b3bface

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -27
app.py CHANGED
@@ -4,12 +4,38 @@ import torch
4
  from torchvision import transforms
5
  from transformers import AutoModelForImageClassification
6
 
7
- # Load the model (no AutoImageProcessor since it is unsupported)
8
- model_name = "anismizi/skin-type-classifier"
9
- model = AutoModelForImageClassification.from_pretrained(model_name)
10
- model.eval() # Set model to eval mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- # Define manual preprocessing transforms similar to ResNet50 training
 
 
 
 
 
 
 
13
  preprocess = transforms.Compose([
14
  transforms.Resize(256),
15
  transforms.CenterCrop(224),
@@ -18,37 +44,33 @@ preprocess = transforms.Compose([
18
  std=[0.229, 0.224, 0.225]),
19
  ])
20
 
21
- # Labels according to model info
22
- labels = ["dry", "oily"]
23
-
24
  def analyze_skin(image: Image.Image):
25
- # Convert input image to RGB
26
  image = image.convert("RGB")
27
- # Preprocess image
28
  input_tensor = preprocess(image)
29
- # Add batch dimension
30
- input_batch = input_tensor.unsqueeze(0)
31
 
32
- # Run inference
33
  with torch.no_grad():
34
- outputs = model(input_batch)
35
- logits = outputs.logits
36
- probabilities = torch.nn.functional.softmax(logits, dim=1)
37
- confidence, predicted_idx = torch.max(probabilities, dim=1)
38
-
39
- predicted_label = labels[predicted_idx.item()]
40
- confidence_score = confidence.item()
41
-
42
- # Format results for display and API output
43
- return {predicted_label: confidence_score}
 
 
 
44
 
45
- # Create Gradio interface
46
  iface = gr.Interface(
47
  fn=analyze_skin,
48
  inputs=gr.Image(type="pil"),
49
- outputs=gr.Label(num_top_classes=2),
50
- title="Skin Condition Analyzer",
51
- description="Classify skin as dry or oily from an image."
52
  )
53
 
54
  if __name__ == "__main__":
 
4
  from torchvision import transforms
5
  from transformers import AutoModelForImageClassification
6
 
7
+ # Define model names and corresponding labels
8
+ MODEL_CONFIGS = [
9
+ {
10
+ "name": "anismizi/skin-type-classifier",
11
+ "labels": ["dry", "oily"],
12
+ "key": "oil_vs_dry"
13
+ },
14
+ {
15
+ "name": "naamalia23/acne-severity-classification",
16
+ "labels": ["no_acne", "acne"],
17
+ "key": "acne"
18
+ },
19
+ {
20
+ "name": "Siraja704/DermaAI",
21
+ "labels": ["no_redness", "redness"],
22
+ "key": "redness"
23
+ },
24
+ {
25
+ "name": "imfarzanansari/skintelligent-wrinkles",
26
+ "labels": ["no_wrinkles", "wrinkles"],
27
+ "key": "wrinkles"
28
+ },
29
+ ]
30
 
31
+ # Load all models at startup
32
+ MODELS = []
33
+ for config in MODEL_CONFIGS:
34
+ model = AutoModelForImageClassification.from_pretrained(config["name"])
35
+ model.eval()
36
+ MODELS.append(model)
37
+
38
+ # Common preprocessing (adjust if any model requires different input specs)
39
  preprocess = transforms.Compose([
40
  transforms.Resize(256),
41
  transforms.CenterCrop(224),
 
44
  std=[0.229, 0.224, 0.225]),
45
  ])
46
 
 
 
 
47
  def analyze_skin(image: Image.Image):
 
48
  image = image.convert("RGB")
 
49
  input_tensor = preprocess(image)
50
+ input_batch = input_tensor.unsqueeze(0) # add batch dimension
 
51
 
52
+ results = {}
53
  with torch.no_grad():
54
+ for idx, config in enumerate(MODEL_CONFIGS):
55
+ model, labels, key = MODELS[idx], config["labels"], config["key"]
56
+ outputs = model(input_batch)
57
+ logits = outputs.logits
58
+ probs = torch.softmax(logits, dim=1)
59
+ confidence, pred_idx = torch.max(probs, dim=1)
60
+ predicted_label = labels[pred_idx.item()]
61
+ confidence_score = confidence.item()
62
+ results[key] = {
63
+ "label": predicted_label,
64
+ "confidence": f"{confidence_score:.2%}"
65
+ }
66
+ return results
67
 
 
68
  iface = gr.Interface(
69
  fn=analyze_skin,
70
  inputs=gr.Image(type="pil"),
71
+ outputs=gr.JSON(label="Skin Analysis Results"),
72
+ title="Comprehensive Skin Condition Analyzer",
73
+ description="Classifies skin image for oily/dry, acne, redness, wrinkles using multiple models."
74
  )
75
 
76
  if __name__ == "__main__":