3v324v23 commited on
Commit
16031b5
·
1 Parent(s): 1a3069d

Huggingface dataset and model

Browse files
Files changed (1) hide show
  1. pages/Model_Evaluation.py +14 -13
pages/Model_Evaluation.py CHANGED
@@ -15,14 +15,14 @@ from sklearn.preprocessing import label_binarize
15
  import streamlit as st
16
  import matplotlib.pyplot as plt
17
  from fpdf import FPDF
 
 
18
 
19
  # ---- Streamlit State Initialization ----
20
  if 'stop_eval' not in st.session_state:
21
  st.session_state.stop_eval = False
22
-
23
  if 'evaluation_done' not in st.session_state:
24
  st.session_state.evaluation_done = False
25
-
26
  if 'trigger_eval' not in st.session_state:
27
  st.session_state.trigger_eval = False
28
 
@@ -33,7 +33,7 @@ st.markdown("<h2 style='color: #2E86C1;'>📈 Model Evaluation</h2>", unsafe_all
33
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
34
  label_map = {label: idx for idx, label in enumerate(class_names)}
35
 
36
- # ---- Text Cleaning Function for PDF ----
37
  def clean_text(text):
38
  return text.encode('utf-8', 'ignore').decode('utf-8')
39
 
@@ -75,7 +75,6 @@ class DDRDataset(Dataset):
75
  image = cv2.imread(img_path)
76
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
77
 
78
- # Apply preprocessing
79
  image = apply_median_filter(image)
80
  image = apply_clahe(image)
81
  image = apply_gamma_correction(image)
@@ -94,34 +93,36 @@ val_transform = transforms.Compose([
94
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
95
  ])
96
 
97
- # ---- Load Data (with caching) ----
98
  @st.cache_resource
99
- def load_test_data(csv_path):
 
 
 
 
100
  dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
101
  return DataLoader(dataset, batch_size=32, shuffle=False)
102
 
103
- # ---- Load Model (with caching) ----
104
  @st.cache_resource
105
  def load_model():
 
106
  model = models.densenet121(pretrained=False)
107
  model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
108
- model.load_state_dict(torch.load(r"D:\\DR_Classification\\training\\Pretrained_Densenet-121.pth", map_location=torch.device('cpu')))
109
  model.eval()
110
  return model
111
 
112
- # ---- Main UI Buttons ----
113
- csv_path = r"D:\\DR_Classification\\splits\\test_labels.csv"
114
  model = load_model()
115
- test_loader = load_test_data(csv_path)
116
 
117
  col1, col2 = st.columns([1, 1])
118
-
119
  with col1:
120
  if st.button("🚀 Start Evaluation"):
121
  st.session_state.stop_eval = False
122
  st.session_state.evaluation_done = False
123
  st.session_state.trigger_eval = True
124
-
125
  with col2:
126
  if st.button("🚩 Stop Evaluation"):
127
  st.session_state.stop_eval = True
 
15
  import streamlit as st
16
  import matplotlib.pyplot as plt
17
  from fpdf import FPDF
18
+ from datasets import load_dataset
19
+ from huggingface_hub import hf_hub_download # ✅ NEW
20
 
21
  # ---- Streamlit State Initialization ----
22
  if 'stop_eval' not in st.session_state:
23
  st.session_state.stop_eval = False
 
24
  if 'evaluation_done' not in st.session_state:
25
  st.session_state.evaluation_done = False
 
26
  if 'trigger_eval' not in st.session_state:
27
  st.session_state.trigger_eval = False
28
 
 
33
  class_names = ['No_DR', 'Mild', 'Moderate', 'Severe', 'Proliferative_DR']
34
  label_map = {label: idx for idx, label in enumerate(class_names)}
35
 
36
+ # ---- Text Cleaning Function ----
37
  def clean_text(text):
38
  return text.encode('utf-8', 'ignore').decode('utf-8')
39
 
 
75
  image = cv2.imread(img_path)
76
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
77
 
 
78
  image = apply_median_filter(image)
79
  image = apply_clahe(image)
80
  image = apply_gamma_correction(image)
 
93
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
94
  ])
95
 
96
+ # ---- Load Data from Hugging Face (cached) ----
97
  @st.cache_resource
98
+ def load_test_data_from_huggingface():
99
+ dataset = load_dataset("Ci-Dave/DDR_dataset_train_test/splits", data_files="test_labels.csv", split="test")
100
+ df = dataset.to_pandas()
101
+ csv_path = "test_labels_temp.csv"
102
+ df.to_csv(csv_path, index=False)
103
  dataset = DDRDataset(csv_path=csv_path, transform=val_transform)
104
  return DataLoader(dataset, batch_size=32, shuffle=False)
105
 
106
+ # ---- Load Model from Hugging Face (cached) ----
107
  @st.cache_resource
108
  def load_model():
109
+ model_path = hf_hub_download(repo_id="Ci-Dave/Densenet121", filename="Pretrained_Densenet-121.pth")
110
  model = models.densenet121(pretrained=False)
111
  model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
112
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
113
  model.eval()
114
  return model
115
 
116
+ # ---- UI Buttons ----
 
117
  model = load_model()
118
+ test_loader = load_test_data_from_huggingface()
119
 
120
  col1, col2 = st.columns([1, 1])
 
121
  with col1:
122
  if st.button("🚀 Start Evaluation"):
123
  st.session_state.stop_eval = False
124
  st.session_state.evaluation_done = False
125
  st.session_state.trigger_eval = True
 
126
  with col2:
127
  if st.button("🚩 Stop Evaluation"):
128
  st.session_state.stop_eval = True