dental-analysis / app.py
Tachii's picture
Update app.py
7a70d0b verified
raw
history blame
7.82 kB
# Updated Conditions Tab with Complete Outlining, Default Colors, and Sagittal Cuts for Deep Caries
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import pandas as pd
# Page-level Streamlit settings, applied once at import time before main()
# draws anything.
_PAGE_CONFIG = dict(
    page_title="DENTBOOK: AI-Powered Dental X-Ray Analysis",
    page_icon="🦷",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
@st.cache_resource
def _load_models_cached():
    """Load the three YOLO models once per process and cache them.

    Raises whatever ``YOLO(...)`` raises. A raised exception is not a return
    value, so nothing is cached on failure and the next rerun retries.
    """
    return (
        YOLO('best.pt'),
        YOLO('disease_model_weights.pt'),
        YOLO('enumerate_model_weights.pt'),
    )


def load_models():
    """Return ``(condition_model, pathology_model, numbering_model)``.

    On any loading error, show it with ``st.error`` and return
    ``(None, None, None)`` so the caller can stop the run. The error handling
    lives outside the cached function on purpose. Returning the ``None``
    triple from inside ``@st.cache_resource`` would cache the failure, and
    the models would never be reloaded until the server restarted.
    """
    try:
        return _load_models_cached()
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Model loading error: {e}")
        return None, None, None
def group_predictions(results):
    """Group the boxes of a single prediction result by class name.

    Args:
        results: one ultralytics result (``model.predict(img)[0]``). Only
            ``results.boxes`` (iterable of boxes with ``xyxy``, ``cls`` and
            ``conf``) and ``results.names`` (class index -> name) are read.

    Returns:
        dict mapping class name to a list of
        ``{'coords': (x1, y1, x2, y2), 'confidence': float}`` entries, in
        detection order. Coordinates are truncated to ``int`` pixels.
        Returns an empty dict when the result has no boxes.
    """
    boxes = results.boxes
    # A result without box output may carry boxes=None; treat that as
    # "nothing detected" instead of failing with TypeError on iteration.
    if boxes is None:
        return {}
    grouped = {}
    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cls = results.names[int(box.cls[0])]
        conf = float(box.conf[0])
        grouped.setdefault(cls, []).append(
            {'coords': (x1, y1, x2, y2), 'confidence': conf}
        )
    return grouped
def crop_detection(image, box_coords, scale=0.4):
img_array = np.array(image)
x1,y1,x2,y2 = box_coords
pad_x, pad_y = int((x2-x1)*0.1), int((y2-y1)*0.1)
h,w = img_array.shape[:2]
x1,y1 = max(0,x1-pad_x), max(0,y1-pad_y)
x2,y2 = min(w,x2+pad_x), min(h,y2+pad_y)
cropped = img_array[y1:y2,x1:x2]
return Image.fromarray(cv2.resize(cropped,(0,0),fx=scale,fy=scale))
# RGB colour per detected class name, used for boxes and fills.
COLORS = {
    "Caries": (255, 0, 0),
    "Decay": (255, 165, 0),
    "Lesion": (255, 105, 180),
    "Crown": (0, 255, 0),
    "Filling": (0, 0, 255),
    "Implant": (255, 215, 0),
    "Impacted": (128, 0, 128),
    "Deep Caries": (220, 20, 60),
}

# Colour for any class name not listed in COLORS.
DEFAULT_COLOR = (255, 0, 0)


def get_color(cls_name):
    """Return the RGB colour tuple for ``cls_name``.

    An exact key match wins. Otherwise the name is compared case-insensitively,
    which matches how ``main`` treats class names (``c.lower()``), so e.g.
    ``"caries"`` still gets the ``"Caries"`` colour. Unknown names, or
    non-string names, fall back to ``DEFAULT_COLOR``.
    """
    exact = COLORS.get(cls_name)
    if exact is not None:
        return exact
    if isinstance(cls_name, str):
        folded = cls_name.casefold()
        for name, rgb in COLORS.items():
            if name.casefold() == folded:
                return rgb
    return DEFAULT_COLOR
def highlight_and_fill_tooth(image, box_coords, cls_name, thickness=3, alpha=0.3):
    """Return a copy of ``image`` with one detection tinted and outlined.

    The box interior is blended with the class colour at ``alpha`` opacity,
    then a solid ``thickness``-pixel border is drawn on top. The input image
    is not modified; drawing happens on fresh array copies.

    NOTE(review): colours come from ``get_color`` as RGB triples and
    presumably assume a 3-channel RGB image; confirm behaviour for
    grayscale or RGBA inputs.
    """
    canvas = np.array(image).copy()
    fill_layer = np.array(image).copy()
    rgb = get_color(cls_name)
    x1, y1, x2, y2 = box_coords
    corner_a, corner_b = (x1, y1), (x2, y2)
    # Solid fill on a scratch copy, blended into the canvas at `alpha`.
    cv2.rectangle(fill_layer, corner_a, corner_b, rgb, -1)
    cv2.addWeighted(fill_layer, alpha, canvas, 1 - alpha, 0, canvas)
    # Opaque outline drawn after blending so it stays crisp.
    cv2.rectangle(canvas, corner_a, corner_b, rgb, thickness)
    return Image.fromarray(canvas)
def draw_sagittal_cut_mm(image, pixel_to_mm_ratio=0.1):
    """Overlay fixed vertical/horizontal guide lines with mm length labels.

    The lines sit at fixed fractions of the image, not at positions derived
    from the image content:
      * a vertical line down the centre column, from 20% to 90% of the height;
      * a horizontal line 75% of the way down that span, 30% of the width long.
    Each is labelled with its length in pixels times ``pixel_to_mm_ratio``.

    NOTE(review): the default ratio of 0.1 is not calibrated anywhere in this
    file. ``main`` applies this to an already-resized crop, so the labels
    measure resized pixels. Confirm before treating them as real distances.
    """
    canvas = np.array(image).copy()
    height, width = canvas.shape[:2]
    center_x = width // 2
    top = int(height * 0.2)
    bottom = int(height * 0.9)
    span_px = bottom - top
    cross_y = top + int(span_px * 0.75)
    half_w = int(width * 0.15)
    left, right = center_x - half_w, center_x + half_w
    vertical_mm = span_px * pixel_to_mm_ratio
    horizontal_mm = (right - left) * pixel_to_mm_ratio
    font = cv2.FONT_HERSHEY_SIMPLEX
    green, cyan = (0, 255, 0), (0, 255, 255)
    # Draw order matches the labels overlapping the lines.
    cv2.line(canvas, (center_x, top), (center_x, bottom), green, 1)
    cv2.line(canvas, (left, cross_y), (right, cross_y), cyan, 1)
    cv2.putText(canvas, f"{vertical_mm:.1f}mm", (center_x + 5, bottom - 5), font, 0.4, green, 1)
    cv2.putText(canvas, f"{horizontal_mm:.1f}mm", (left, cross_y - 5), font, 0.4, cyan, 1)
    return Image.fromarray(canvas)
def create_confidence_chart(data):
df=pd.DataFrame([{"Condition":c,"Confidence":d["confidence"]}for c,ds in data.items()for d in ds])
return px.box(df,x='Condition',y='Confidence',points="all",title='Confidence Distribution')
def create_condition_count_chart(data):
    """Pie chart of how many detections each condition produced.

    Slices follow the iteration order of ``data``; each slice's value is the
    length of that condition's detection list.
    """
    labels = list(data)
    sizes = [len(data[label]) for label in labels]
    fig = go.Figure(go.Pie(labels=labels, values=sizes))
    fig.update_layout(title='Condition Distribution')
    return fig
# Sample X-rays offered in the "Or select example" picker: label -> file path.
EXAMPLES = {
    "Example 1": "dental3.png",
    "Example 2": "dental2.jpeg",
    "Example 3": "dental1.jpg",
    "Example 4": "dental4.jpg",
}
# Sidebar "About" copy, rendered by st.info.
_ABOUT_TEXT = """
DentBook AI – Next-Level Dental X-ray Diagnostics
DentBook AI leverages cutting-edge artificial intelligence to provide comprehensive analysis of dental X-rays, identifying and classifying a wide range of dental conditions and anatomical features.
🦷 **Common Dental Issues**
- Detection of Tooth Decay and Caries
- Identification of Missing and Cracked Teeth
- Differentiation Between Primary and Permanent Teeth
- Assessment of Tooth Wear and Attrition
👨‍⚕️ **Restorations and Dental Procedures**
- Identification of Crowns and Fillings
- Detection of Dental Implants and Abutments
- Evaluation of Root Canal Treatments
- Recognition of Post and Core Restorations and Gingival Formers
🎯 **Orthodontic Appliances and Features**
- Detection of Tooth Misalignment
- Localization of Brackets and Orthodontic Wires
- Identification of Fixed Retainers
- Detection of Temporary Anchorage Devices (TADs) and Metal Bands
🔍 **Bone Structure and Soft Tissue Evaluation**
- Analysis of the Mandibular Canal
- Assessment of Maxillary Sinus Health
- Identification of Bone Deficiencies and Defects
- Detection of Cysts and Related Pathologies
⚠️ **Complex and Specialized Conditions**
- Detection of Impacted Teeth
- Identification of Periapical Pathologies
- Recognition of Retained Roots and Fragments
- Evaluation of Root Resorption and Supraeruption
"""


def _render_sidebar():
    """Draw the static 'About' panel in the sidebar."""
    with st.sidebar:
        st.title("About")
        st.info(_ABOUT_TEXT)


def _load_input_image():
    """Return the X-ray chosen by the user as an RGB PIL image, or None.

    An uploaded file takes precedence over the example picker; only the
    chosen source is opened. The image is converted to RGB so that the RGB
    colours drawn later keep their colour on grayscale, palette or RGBA
    X-rays.
    """
    uploaded = st.file_uploader("Upload X-ray", ['png', 'jpg', 'jpeg'])
    example = st.selectbox("Or select example", ["None"] + list(EXAMPLES))
    if uploaded:
        source = uploaded
    elif example != "None":
        source = EXAMPLES[example]
    else:
        return None
    try:
        # convert() returns a fully loaded copy, so the opened file can be released.
        with Image.open(source) as opened:
            return opened.convert("RGB")
    except OSError as err:  # unreadable upload or missing example file
        st.error(f"Could not open image: {err}")
        return None


def _render_detections(title, dataset, img):
    """Render charts plus a full view and close-up for every detection in ``dataset``.

    ``dataset`` is the class-name -> detections mapping from
    ``group_predictions``. ``title`` names the tab and keys its charts.
    """
    if not dataset:
        # Nothing detected: there is no data to chart, so say so instead.
        st.info(f"No {title.lower()} detected.")
        return
    st.header("Advanced Visualizations")
    viz1, viz2 = st.columns(2)
    # Explicit keys keep the two tabs' charts distinct element IDs.
    viz1.plotly_chart(create_confidence_chart(dataset), key=f"{title}_confidence")
    viz2.plotly_chart(create_condition_count_chart(dataset), key=f"{title}_counts")
    for idx, (c, dets) in enumerate(dataset.items(), 1):
        st.subheader(f"{idx}. {c} ({len(dets)})")
        for i, det in enumerate(dets, 1):
            st.markdown(f"**{i}. {c}:** {det['confidence']:.2%}")
            cols = st.columns([2, 1])
            cols[0].image(
                highlight_and_fill_tooth(img, det['coords'], c),
                caption="Full View",
                use_container_width=True,
            )
            cropped = crop_detection(img, det['coords'])
            if c.lower() in ("caries", "deep caries"):
                cropped = draw_sagittal_cut_mm(cropped)
                caption = "Close-up & Sagittal"
            else:
                caption = "Close-up"
            cols[1].image(cropped, caption=caption, use_container_width=True)
        # Separate one condition's detections from the next.
        # NOTE(review): SOURCE lost its indentation; divider placement reconstructed.
        st.divider()


def main():
    """Streamlit entry point: load models, pick an X-ray, run and display the analysis."""
    models = load_models()
    cond_model, path_model, tooth_model = models
    # load_models() signals failure with Nones (after showing st.error).
    if any(model is None for model in models):
        st.stop()
    _render_sidebar()
    st.title("🦷 DENTBOOK: AI-Powered Dental X-Ray Analysis")
    img = _load_input_image()
    if img is None:
        return
    st.image(img, use_container_width=True)
    if not st.button("🔍 Analyze with AI"):
        return
    with st.spinner("Analyzing..."):
        conditions = group_predictions(cond_model.predict(img)[0])
        pathologies = group_predictions(path_model.predict(img)[0])
        numbering = tooth_model.predict(img)[0]
    tabs = st.tabs(["🦷 Pathologies", "⚙️ Conditions", "🔢 Tooth Numbering"])
    for title, dataset, tab in (
        ("Pathologies", pathologies, tabs[0]),
        ("Conditions", conditions, tabs[1]),
    ):
        with tab:
            _render_detections(title, dataset, img)
    with tabs[2]:
        # ultralytics Results.plot() returns a BGR array; st.image assumes RGB
        # unless told otherwise, which would swap red and blue.
        st.image(
            numbering.plot(line_width=1),
            caption="Tooth Numbering",
            channels="BGR",
            use_container_width=True,
        )


if __name__ == "__main__":
    main()