File size: 5,456 Bytes
7727729
 
c061347
7727729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2822729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7727729
531eb3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7727729
 
 
 
 
 
 
 
2822729
7727729
531eb3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7727729
 
 
 
 
 
 
 
 
 
 
 
 
531eb3d
7727729
531eb3d
 
7727729
 
 
531eb3d
 
 
 
 
 
 
 
 
 
7727729
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from ultralytics import YOLO
import streamlit as st
from PIL import Image
import config

@st.cache_resource
def load_model(model_path):
    """
    Build a YOLO model from a weights file.

    The function is wrapped in ``st.cache_resource``, so Streamlit decides
    when the body actually runs for a given ``model_path``.

    Parameters:
        model_path (str): Location of the YOLO model file handed to ``YOLO``.
    Returns:
        The ``YOLO`` instance constructed from ``model_path``.
    """
    return YOLO(model_path)

# Display label for each integer class index; a label's position in the
# tuple below is its index (0 = "Acne", ..., 12 = "Whitehead").
CLASS_NAMES = dict(enumerate((
    "Acne",
    "Pimples",
    "Acne Scars",
    "Blackhead",
    "Cystic",
    "Flat Wart",
    "Folliculitis",
    "Keloid",
    "Milium",
    "Papular",
    "Purulent",
    "Sebo-Crystan-Conglo",
    "Whitehead",
)))

# Product suggestions keyed by condition label; each value is a ready-to-show,
# comma-separated display string.
PRODUCT_RECOMMENDATIONS = dict((
    ("Acne", "Salicylic acid cleanser, Non-comedogenic moisturizer"),
    ("Pimples", "Benzoyl peroxide spot treatment, Oil-free sunscreen"),
    ("Acne Scars", "Vitamin C serum, Hyaluronic acid"),
    ("Blackhead", "Charcoal mask, Pore strips"),
    ("Cystic", "Tea tree oil, Spot patches"),
    ("Flat Wart", "Over-the-counter salicylic acid"),
    ("Folliculitis", "Antibacterial wash, Topical cream"),
    ("Keloid", "Silicone-based scar sheets"),
    ("Milium", "Retinol cream, Exfoliating cleanser"),
    ("Papular", "Witch hazel toner, Niacinamide serum"),
    ("Purulent", "Antiseptic cream, Medicated bandages"),
    ("Sebo-Crystan-Conglo", "Clay mask, Oil-control moisturizer"),
    ("Whitehead", "Gentle exfoliating scrub, AHA/BHA toner"),
))

# Treatment suggestions keyed by condition label; each value is a ready-to-show,
# comma-separated display string.
TREATMENT_RECOMMENDATIONS = dict((
    ("Acne", "Regular exfoliation, Avoiding heavy makeup"),
    ("Pimples", "Gentle cleansing routine, Regular hydration"),
    ("Acne Scars", "Microneedling, Chemical peels"),
    ("Blackhead", "Manual extraction by a professional, Laser therapy"),
    ("Cystic", "Corticosteroid injections, Oral antibiotics"),
    ("Flat Wart", "Cryotherapy, Electrosurgery"),
    ("Folliculitis", "Warm compresses, Antibiotic therapy"),
    ("Keloid", "Corticosteroid injections, Laser treatment"),
    ("Milium", "Professional extraction, Topical retinoids"),
    ("Papular", "Blue light therapy, Topical treatments"),
    ("Purulent", "Incision and drainage, Oral antibiotics"),
    ("Sebo-Crystan-Conglo", "Isotretinoin therapy, Photodynamic therapy"),
    ("Whitehead", "Steam and extraction, Preventative skincare routine"),
))

def count_objects(boxes, class_names=None):
    """
    Count detections per class label.

    Parameters:
        boxes: Iterable of detection boxes. Each item must expose
            ``cls.item()`` returning the numeric class index. ``None`` is
            treated as "no detections" instead of raising ``TypeError``.
        class_names (dict | None): Optional mapping from integer class index
            to display label. When omitted, the module-level ``CLASS_NAMES``
            table is used.
    Returns:
        dict: Label -> number of boxes with that label, in first-seen order.
        An index missing from the mapping is reported as ``"Class <index>"``.
    """
    # A result object may carry no box container at all; report zero counts.
    if boxes is None:
        return {}

    names = CLASS_NAMES if class_names is None else class_names
    counts = {}
    for box in boxes:
        # ``item()`` may yield a float (e.g. 3.0); normalise to an int key.
        class_index = int(box.cls.item())
        label = names.get(class_index, f"Class {class_index}")
        counts[label] = counts.get(label, 0) + 1
    return counts


def display_object_counts(counts, col):
    """
    Render a two-card row for every detected condition.

    Parameters:
        counts (dict): Label -> count mapping; entries whose count is not
            greater than zero are skipped.
        col: Streamlit column; rendering happens inside ``col.container()``.
    Returns:
        None. Output is written through ``st.markdown`` as raw HTML.
    """
    with col.container():
        for label, total in counts.items():
            # Nothing to show for a condition that was not detected.
            if not total > 0:
                continue

            products = PRODUCT_RECOMMENDATIONS.get(label, 'No products available.')
            treatments = TREATMENT_RECOMMENDATIONS.get(label, 'No treatments available.')

            # Narrow summary card on the left, wider advice card on the right.
            summary_col, advice_col = st.columns([1, 2])

            with summary_col:
                # Label and count, styled as a dark card via inline HTML.
                st.markdown(f"""
                        <div style="background-color: #262730; border-radius: 10px; padding: 20px; margin: 10px 0; box-shadow: 0 2px 4px #484a55;">
                            <h4 style="color: #fafafa; margin: 0;">{label}</h4>
                            <h5 style="color: #ff4b4b; margin: 0;">Count: {total}</h5>
                        </div>
                    """, unsafe_allow_html=True)

            with advice_col:
                # Product and treatment suggestions for the same label.
                st.markdown(f"""
                        <div style="background-color: #262730; border-radius: 10px; padding: 20px; margin: 10px 0; box-shadow: 0 2px 4px #484a55;">
                            <h4 style="color: #ff4b4b; margin: 0;">Recommended Products</h4>
                            <p style="color: #fafafa;">{products}</p>
                            <h4 style="color: #ff4b4b; margin: 0;">Recommended Treatments</h4>
                            <p style="color: #fafafa;">{treatments}</p>
                        </div>
                    """, unsafe_allow_html=True)


def infer_uploaded_image(conf, model):
    """
    Run detection on an image uploaded through the sidebar and show the result.

    :param conf: Confidence value forwarded to ``model.predict`` as ``conf``.
    :param model: Object exposing ``predict(image, conf=...)``; the first
        returned result must provide ``boxes`` and ``plot()``.
    :return: None
    """
    source_img = st.sidebar.file_uploader(
        label="Choose an image...",
        type=("jpg", "jpeg", "png", 'bmp', 'webp')
    )

    # Left: uploaded preview. Right (wider): annotated result and cards.
    preview_col, result_col = st.columns([1, 2])

    # Nothing uploaded yet: leave both columns empty.
    if not source_img:
        return

    with preview_col:
        uploaded_image = Image.open(source_img)
        st.image(image=source_img, caption="Uploaded Image", use_column_width=True)

    # Inference only runs on the rerun triggered by the button press.
    if not st.button("A N A L Y Z E", key="analyze_button"):
        return

    with st.spinner("Running..."):
        results = model.predict(uploaded_image, conf=conf)
        first_result = results[0]
        detections = first_result.boxes
        # Reverse the last axis of the plotted array (presumably BGR -> RGB
        # for display; confirm against the model library's plot() contract).
        annotated = first_result.plot()[:, :, ::-1]

        with result_col:
            st.image(annotated, caption="Analyzed Image", use_column_width=True)
            display_object_counts(count_objects(detections), result_col)