HilmiZr's picture
added: recommendations
531eb3d
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from ultralytics import YOLO
import streamlit as st
from PIL import Image
import config
@st.cache_resource
def load_model(model_path):
    """
    Build a YOLO object-detection model from the weights file at ``model_path``.

    The function is wrapped in ``st.cache_resource``, so Streamlit caches the
    returned model per ``model_path`` instead of reconstructing it on every
    script rerun.

    :param model_path: Path to the YOLO model file (str).
    :return: The ``YOLO`` model instance built from ``model_path``.
    """
    return YOLO(model_path)
# Class index emitted by the detector -> human-readable skin-condition label.
# The tuple position IS the class index, so keep this order in sync with the
# class order the model was trained with.
CLASS_NAMES = dict(enumerate((
    "Acne",
    "Pimples",
    "Acne Scars",
    "Blackhead",
    "Cystic",
    "Flat Wart",
    "Folliculitis",
    "Keloid",
    "Milium",
    "Papular",
    "Purulent",
    "Sebo-Crystan-Conglo",
    "Whitehead",
)))
# Condition label -> comma-separated product suggestions shown in the
# "Recommended Products" card. Keys are looked up by the labels produced from
# CLASS_NAMES, so they must match those labels exactly.
PRODUCT_RECOMMENDATIONS = dict((
    ("Acne", "Salicylic acid cleanser, Non-comedogenic moisturizer"),
    ("Pimples", "Benzoyl peroxide spot treatment, Oil-free sunscreen"),
    ("Acne Scars", "Vitamin C serum, Hyaluronic acid"),
    ("Blackhead", "Charcoal mask, Pore strips"),
    ("Cystic", "Tea tree oil, Spot patches"),
    ("Flat Wart", "Over-the-counter salicylic acid"),
    ("Folliculitis", "Antibacterial wash, Topical cream"),
    ("Keloid", "Silicone-based scar sheets"),
    ("Milium", "Retinol cream, Exfoliating cleanser"),
    ("Papular", "Witch hazel toner, Niacinamide serum"),
    ("Purulent", "Antiseptic cream, Medicated bandages"),
    ("Sebo-Crystan-Conglo", "Clay mask, Oil-control moisturizer"),
    ("Whitehead", "Gentle exfoliating scrub, AHA/BHA toner"),
))
# Condition label -> comma-separated treatment suggestions shown in the
# "Recommended Treatments" card. Keys must match the labels in CLASS_NAMES.
TREATMENT_RECOMMENDATIONS = dict((
    ("Acne", "Regular exfoliation, Avoiding heavy makeup"),
    ("Pimples", "Gentle cleansing routine, Regular hydration"),
    ("Acne Scars", "Microneedling, Chemical peels"),
    ("Blackhead", "Manual extraction by a professional, Laser therapy"),
    ("Cystic", "Corticosteroid injections, Oral antibiotics"),
    ("Flat Wart", "Cryotherapy, Electrosurgery"),
    ("Folliculitis", "Warm compresses, Antibiotic therapy"),
    ("Keloid", "Corticosteroid injections, Laser treatment"),
    ("Milium", "Professional extraction, Topical retinoids"),
    ("Papular", "Blue light therapy, Topical treatments"),
    ("Purulent", "Incision and drainage, Oral antibiotics"),
    ("Sebo-Crystan-Conglo", "Isotretinoin therapy, Photodynamic therapy"),
    ("Whitehead", "Steam and extraction, Preventative skincare routine"),
))
def count_objects(boxes):
    """
    Tally detections per class label.

    :param boxes: Iterable of detection boxes (e.g. ``res[0].boxes``); each item
        must expose ``cls`` with an ``.item()`` method returning the class index.
    :return: dict mapping label -> number of boxes with that label, in the order
        labels are first seen. Indices missing from CLASS_NAMES are labelled
        ``"Class <index>"``.
    """
    tally = {}
    labels = (
        CLASS_NAMES.get(index, f"Class {index}")
        for index in (int(box.cls.item()) for box in boxes)
    )
    for label in labels:
        tally.setdefault(label, 0)
        tally[label] += 1
    return tally
def display_object_counts(counts, col):
    """
    Render a pair of result cards for every detected condition.

    For each ``(label, count)`` entry with a positive count, a ``[1, 2]`` column
    pair is created inside ``col.container()``. The narrow left card shows the
    label and count. The wide right card shows the matching entries from
    PRODUCT_RECOMMENDATIONS and TREATMENT_RECOMMENDATIONS, or a fallback message
    when the label has none.

    :param counts: Mapping of class label -> number of detections, such as the
        dict returned by ``count_objects``.
    :param col: Streamlit column/container whose ``.container()`` hosts the cards.
    :return: None
    """
    card_style = (
        "background-color: #262730; border-radius: 10px; padding: 20px; "
        "margin: 10px 0; box-shadow: 0 2px 4px #484a55;"
    )

    def render_card(*body_lines):
        # Wrap the given HTML lines in a styled <div>, with a leading and a
        # trailing newline, one element per line, and emit it as raw HTML.
        html = "\n".join(["", f'<div style="{card_style}">', *body_lines, "</div>", ""])
        st.markdown(html, unsafe_allow_html=True)

    with col.container():
        # Only conditions that were actually detected get a card row.
        detected = ((label, n) for label, n in counts.items() if n > 0)
        for label, n in detected:
            summary_col, advice_col = st.columns([1, 2])
            with summary_col:
                render_card(
                    f'<h4 style="color: #fafafa; margin: 0;">{label}</h4>',
                    f'<h5 style="color: #ff4b4b; margin: 0;">Count: {n}</h5>',
                )
            with advice_col:
                products = PRODUCT_RECOMMENDATIONS.get(label, 'No products available.')
                treatments = TREATMENT_RECOMMENDATIONS.get(label, 'No treatments available.')
                render_card(
                    '<h4 style="color: #ff4b4b; margin: 0;">Recommended Products</h4>',
                    f'<p style="color: #fafafa;">{products}</p>',
                    '<h4 style="color: #ff4b4b; margin: 0;">Recommended Treatments</h4>',
                    f'<p style="color: #fafafa;">{treatments}</p>',
                )
def infer_uploaded_image(conf, model):
    """
    Execute inference for an image uploaded through the sidebar.

    Renders a sidebar file uploader and a two-column layout. Once an image is
    uploaded, it is shown in the left column together with an analyze button.
    Pressing the button runs ``model.predict`` and shows the annotated image
    plus per-condition result cards in the right column. A message is shown
    instead of a traceback when the upload cannot be decoded as an image, and
    when the model finds nothing above the confidence threshold.

    :param conf: Confidence threshold forwarded to ``model.predict``.
    :param model: A loaded YOLO model (for example, the one returned by ``load_model``).
    :return: None
    """
    source_img = st.sidebar.file_uploader(
        label="Choose an image...",
        type=("jpg", "jpeg", "png", 'bmp', 'webp')
    )
    col1, col2 = st.columns([1, 2])  # narrow input column, wide results column
    if source_img:
        with col1:
            try:
                uploaded_image = Image.open(source_img)
            except Image.UnidentifiedImageError:
                # The uploader only filters by file extension; a corrupt or
                # mislabelled file would otherwise crash the page with a traceback.
                st.error("The uploaded file could not be read as an image.")
                return
            st.image(image=source_img, caption="Uploaded Image", use_column_width=True)
            if st.button("A N A L Y Z E", key="analyze_button"):
                with st.spinner("Running..."):
                    res = model.predict(uploaded_image, conf=conf)
                    boxes = res[0].boxes
                    # Reverse the channel axis of the plotted array before display.
                    # Presumably plot() returns BGR (OpenCV order), so this yields RGB.
                    res_plotted = res[0].plot()[:, :, ::-1]
                with col2:
                    st.image(res_plotted, caption="Analyzed Image", use_column_width=True)
                    object_counts = count_objects(boxes)
                    if object_counts:
                        display_object_counts(object_counts, col2)
                    else:
                        # Without this the user sees no cards and no explanation.
                        st.info("No skin conditions were detected at the current confidence threshold.")