# -*- coding: utf-8 -*-
# Standard library
import json  # Added for saving/loading metadata
import os
import random
import shutil
# import tempfile # Not used for primary storage here

# Third-party
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, ConcatDataset
from PIL import Image, ImageOps  # Added ImageOps
import matplotlib.pyplot as plt
# from sklearn.decomposition import PCA # PCA not used, removed
import numpy as np
import cv2
from ultralytics import YOLO
import pandas as pd

# --- README.md Configuration Reminder ---
# Make sure your README.md contains at least:
# ---
# hardware: your_gpu_id_here # e.g., t4-small (Required for persistent storage)
# storage:
#   mount_point: /data
# ---
# And other keys like sdk, app_file.
# ----------------------------------------

st.set_page_config(layout="wide")

# --- Constants and Setup ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Path Configuration ---
# Directory that contains this script; repository assets are resolved from it.
APP_DIR = os.path.dirname(os.path.abspath(__file__))

# --- Persistent Storage Mount Point ---
PERSISTENT_STORAGE_MOUNT_POINT = "/data"  # Standard mount point

# Static assets shipped with the Git repository.
BASE_DATASET = os.path.join(APP_DIR, "App data/data_optimize/Basee_data")
MODEL_WEIGHTS_PATH = os.path.join(APP_DIR, "model/efficientnet_coffee (1).pth")
YOLO_MODEL_PATH = os.path.join(APP_DIR, "model/best.pt")

# Dynamic data (user-added classes, saved few-shot states) lives on the
# persistent volume.
RARE_DATASET = os.path.join(PERSISTENT_STORAGE_MOUNT_POINT, "Rare_data")
SAVED_MODELS_DIR = os.path.join(PERSISTENT_STORAGE_MOUNT_POINT, "saved_few_shot_models")

# --- Ensure Persistent Directories Exist ---
# The app cannot work without its persistent directories, so any failure here
# is reported and the script run is stopped.
try:
    for persistent_dir in (RARE_DATASET, SAVED_MODELS_DIR):
        os.makedirs(persistent_dir, exist_ok=True)
except OSError as e:
    st.error(f"Fatal: Error creating directories in persistent storage ('{PERSISTENT_STORAGE_MOUNT_POINT}'). Check Space config/permissions. Error: {e}")
    st.stop()  # Stop execution if persistent storage dirs can't be created
except Exception as e:
    st.error(f"Fatal: An unexpected error occurred accessing persistent storage: {e}")
    st.stop()

# --- Check that the required repository files/folders exist ---
if not os.path.isdir(BASE_DATASET):
    st.error(f"Base dataset directory not found: {BASE_DATASET}")
    st.stop()

if not os.path.isfile(MODEL_WEIGHTS_PATH):
    st.error(f"Classifier weights file not found: {MODEL_WEIGHTS_PATH}")
    st.stop()

if not os.path.isfile(YOLO_MODEL_PATH):
    st.error(f"YOLO detection model file not found: {YOLO_MODEL_PATH}")
    st.stop()

st.sidebar.info(f"Using device: {DEVICE}")


# --- Helper Functions for Saving/Loading Few-Shot States ---
# (These rely on SAVED_MODELS_DIR, which points into persistent storage.)
def list_saved_models():
    """Return the names of all saved few-shot model states.

    Every sub-directory of ``SAVED_MODELS_DIR`` counts as one saved state.

    Returns:
        list[str]: Directory names in sorted order. An empty list is returned
        when the storage directory is missing or cannot be listed (a listing
        error is reported through ``st.error``).
    """
    if not os.path.isdir(SAVED_MODELS_DIR):
        return []
    try:
        # os.listdir() order is arbitrary; sort so the sidebar select boxes
        # present the entries in a stable order across reruns.
        return sorted(
            d for d in os.listdir(SAVED_MODELS_DIR)
            if os.path.isdir(os.path.join(SAVED_MODELS_DIR, d))
        )
    except Exception as e:
        st.error(f"Error listing {SAVED_MODELS_DIR}: {e}")
        return []


def save_few_shot_state(name, model, prototypes, proto_labels, current_class_names, few_shot_strategy):
    """Persist a few-shot state (model weights, prototypes, metadata) to disk.

    The state is written to ``SAVED_MODELS_DIR/<sanitized name>`` as three
    files: ``feature_extractor_state_dict.pth``, ``prototypes.pt`` and
    ``metadata.json``.

    Args:
        name: User-supplied name; only alphanumerics, ``_`` and ``-`` are kept.
        model: Module whose ``state_dict()`` is saved.
        prototypes: Tensor of class prototypes (saved as a CPU tensor).
        proto_labels: Sequence of integer labels, one per prototype.
        current_class_names: Class names active at save time.
        few_shot_strategy: Must be ``'train_projection'``.

    Returns:
        bool: True when everything was written, False on any validation
        failure, cancelled overwrite, or I/O error (reported via Streamlit).
    """
    if not name or not name.strip():
        st.error("Please provide a valid name for the saved model.")
        return False

    sanitized_name = "".join(c for c in name if c.isalnum() or c in ('_', '-')).rstrip()
    if not sanitized_name:
        st.error("Invalid name after sanitization. Use letters, numbers, underscore, or hyphen.")
        return False

    # Validate the strategy BEFORE touching the filesystem. Previously this
    # check ran after the files were written and removed the directory on
    # failure, which destroyed an existing saved state during an overwrite.
    if few_shot_strategy != 'train_projection':
        st.error(f"Internal Error: Attempting to save with invalid strategy '{few_shot_strategy}'. Expected 'train_projection'. Save cancelled.")
        return False

    save_dir = os.path.join(SAVED_MODELS_DIR, sanitized_name)

    if os.path.exists(save_dir):
        # NOTE(review): if this function is itself invoked from a button
        # handler, clicking "Overwrite?" triggers a rerun in which the outer
        # button is no longer pressed - confirm the caller preserves state.
        col1, col2 = st.columns([3, 1])
        with col1:
            st.warning(f"Model name '{sanitized_name}' already exists.")
        with col2:
            overwrite_key = f"overwrite_button_{sanitized_name}"
            if not st.button("Overwrite?", key=overwrite_key):
                st.info("Save cancelled. Choose a different name or click 'Overwrite?'.")
                return False
            st.info(f"Overwriting '{sanitized_name}'...")
    else:
        st.info(f"Saving new model state '{sanitized_name}'...")

    try:
        os.makedirs(save_dir, exist_ok=True)

        # 1. Save Model State Dictionary (as CPU tensors). The model is always
        #    moved back to DEVICE, even when torch.save fails, so the shared
        #    model object is never left on the wrong device.
        model_path = os.path.join(save_dir, "feature_extractor_state_dict.pth")
        model.to('cpu')
        try:
            torch.save(model.state_dict(), model_path)
        finally:
            model.to(DEVICE)

        # 2. Save Prototypes Tensor
        prototypes_path = os.path.join(save_dir, "prototypes.pt")
        torch.save(prototypes.cpu(), prototypes_path)

        # 3. Save Metadata (plain Python types so json can serialize them)
        metadata = {
            "prototype_labels": [int(label) for label in proto_labels],
            "class_names_on_save": list(current_class_names),
            "few_shot_strategy": 'train_projection',
        }
        metadata_path = os.path.join(save_dir, "metadata.json")
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=4)

        st.success(f"Few-shot model state '{sanitized_name}' saved successfully!")
        return True

    except Exception as e:
        st.error(f"Error saving model state '{sanitized_name}' to {save_dir}: {e}")
        st.exception(e)
        # A partially written directory would later be reported as an
        # incomplete/corrupt state, so remove it.
        if os.path.exists(save_dir):
            try:
                shutil.rmtree(save_dir)
                st.info(f"Cleaned up partially saved directory '{save_dir}'.")
            except Exception as cleanup_e:
                st.error(f"Error cleaning up directory during save failure: {cleanup_e}")
        return False


def load_few_shot_state(name, model_to_load_into, current_class_names):
    """Load a saved few-shot state into the model and ``st.session_state``.

    Args:
        name: Name of the saved state (a sub-directory of ``SAVED_MODELS_DIR``).
        model_to_load_into: Module that receives the saved ``state_dict``.
        current_class_names: Currently active class names, used to warn about
            a mismatch with the names stored at save time.

    Returns:
        bool: True on success. On failure False is returned and the problem
        is reported via Streamlit; an unexpected exception additionally
        resets the few-shot session state to standard mode.
    """
    load_dir = os.path.join(SAVED_MODELS_DIR, name)
    if not os.path.isdir(load_dir):
        st.error(f"Saved model directory '{load_dir}' not found.")
        return False

    model_path = os.path.join(load_dir, "feature_extractor_state_dict.pth")
    prototypes_path = os.path.join(load_dir, "prototypes.pt")
    metadata_path = os.path.join(load_dir, "metadata.json")

    if not all(os.path.exists(p) for p in [model_path, prototypes_path, metadata_path]):
        st.error(f"Saved model '{name}' is incomplete. Files missing in '{load_dir}'.")
        return False

    try:
        # 1. Load Metadata
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        loaded_proto_labels = metadata.get("prototype_labels")
        saved_class_names = metadata.get("class_names_on_save")
        loaded_strategy = metadata.get("few_shot_strategy")

        if loaded_proto_labels is None or saved_class_names is None or loaded_strategy is None:
            st.error(f"Metadata file for '{name}' is corrupted or missing required keys (labels, class_names, strategy).")
            return False

        if loaded_strategy != 'train_projection':
            st.error(f"Saved model '{name}' used strategy '{loaded_strategy}', but only 'train_projection' (frozen backbone) is currently supported. Cannot load.")
            return False

        # Prototype labels are positional indices into the class-name list,
        # so order matters: compare the lists, not just their sets.
        if list(saved_class_names) != list(current_class_names):
            st.warning("⚠️ **Class Mismatch!**")
            st.warning(f"Saved model '{name}' classes: `{saved_class_names}`")
            st.warning(f"Current active classes: `{current_class_names}`")
            st.warning("Predictions might be incorrect or errors may occur. Proceed with caution.")

        # 2. Read both tensor files before mutating the (shared) model, so a
        #    corrupt prototypes file cannot leave the model half-updated.
        # NOTE: torch.load unpickles data - only load files this app wrote.
        state_dict = torch.load(model_path, map_location=DEVICE)
        loaded_prototypes = torch.load(prototypes_path, map_location=DEVICE)

        if loaded_prototypes.shape[0] != len(loaded_proto_labels):
            st.error(f"Saved model '{name}' is inconsistent: {loaded_prototypes.shape[0]} prototypes but {len(loaded_proto_labels)} labels.")
            return False

        # 3. Load Model State Dictionary
        model_to_load_into.to(DEVICE)
        try:
            missing_keys, unexpected_keys = model_to_load_into.load_state_dict(state_dict, strict=True)
            if missing_keys:
                st.warning(f"Loaded state dict is missing keys: {missing_keys}")
            if unexpected_keys:
                st.warning(f"Loaded state dict has unexpected keys: {unexpected_keys}")
        except RuntimeError as e:
            st.error(f"RuntimeError loading state_dict for '{name}'. Architecture mismatch? {e}")
            st.error("This usually means the saved model structure (base + projection) doesn't match the current code's structure.")
            return False
        model_to_load_into.eval()

        # 4. Update Session State
        st.session_state.final_prototypes = loaded_prototypes
        st.session_state.prototype_labels = loaded_proto_labels
        st.session_state.few_shot_strategy = loaded_strategy
        st.session_state.few_shot_trained = True
        st.session_state.model_mode = 'few_shot'

        st.success(f"Successfully loaded few-shot model state '{name}' (Strategy: {loaded_strategy}). Mode set to Few-Shot.")
        return True

    except Exception as e:
        st.error(f"Error loading model state '{name}' from {load_dir}: {e}")
        st.exception(e)
        # Reset state if loading fails partially
        st.session_state.final_prototypes = None
        st.session_state.prototype_labels = None
        st.session_state.few_shot_strategy = None
        st.session_state.few_shot_trained = False
        st.session_state.model_mode = 'standard'
        return False


def delete_saved_model(name):
    """Delete the saved few-shot state called ``name``.

    Args:
        name: Name of a direct sub-directory of ``SAVED_MODELS_DIR``.

    Returns:
        bool: True when the directory was removed, False otherwise (the
        reason is reported via ``st.error``).
    """
    # An empty name would resolve to SAVED_MODELS_DIR itself and ".." to its
    # parent; refuse anything that is not a plain directory name so rmtree
    # can never remove more than one saved state.
    if not name or name in (".", "..") or os.path.basename(name) != name:
        st.error(f"Cannot delete. Invalid saved model name '{name}'.")
        return False

    delete_dir = os.path.join(SAVED_MODELS_DIR, name)
    if not os.path.isdir(delete_dir):
        st.error(f"Cannot delete. Saved model '{name}' not found in {SAVED_MODELS_DIR}.")
        return False
    try:
        shutil.rmtree(delete_dir)
        st.success(f"Deleted saved model '{name}'.")
        return True
    except Exception as e:
        st.error(f"Error deleting saved model '{name}' from {delete_dir}: {e}")
        return False


# --- Model Architectures ---
class EfficientNetWithProjection(nn.Module):
    """Backbone followed by a single linear projection layer.

    Args:
        base_model: Module producing ``in_features``-dimensional features.
        output_dim: Size of the projected embedding.
        in_features: Feature size produced by ``base_model``; defaults to
            1280, the value the original code hard-coded.
    """

    def __init__(self, base_model, output_dim=1024, in_features=1280):
        super().__init__()
        self.model = base_model
        self.projection = nn.Linear(in_features, output_dim)

    def forward(self, x):
        """Return the projected features of ``x``."""
        features = self.model(x)
        return self.projection(features)


def get_base_efficientnet_architecture(num_classes=5):
    """Build an untrained EfficientNet-B0 whose head outputs ``num_classes``."""
    model = models.efficientnet_b0(weights=None)
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model


def get_feature_extractor_base():
    """Return the trained backbone with its classifier replaced by Identity.

    Weights are read from ``MODEL_WEIGHTS_PATH`` with ``strict=False``;
    missing keys and unexpected non-classifier keys are surfaced as warnings.
    Any loading error is reported and stops the script run.
    """
    base_model = get_base_efficientnet_architecture(num_classes=5)
    try:
        state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE)
        missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=False)
        if unexpected_keys and not all(k.startswith('classifier.') for k in unexpected_keys):
            st.warning(f"Loading base weights: Unexpected keys found beyond classifier: {unexpected_keys}")
        if missing_keys:
            st.warning(f"Loading base weights: Missing keys: {missing_keys}")
    except Exception as e:
        st.error(f"Error loading model weights from {MODEL_WEIGHTS_PATH} into base architecture: {e}")
        st.exception(e)
        st.stop()

    # Drop the classification head so the model emits raw features.
    base_model.classifier = nn.Identity()
    base_model.eval()
    return base_model


def load_standard_classifier():
    """Return the 5-class classifier loaded from ``MODEL_WEIGHTS_PATH``.

    The model is moved to ``DEVICE`` and put in eval mode. Any loading error
    is reported and stops the script run.
    """
    model = get_base_efficientnet_architecture(num_classes=5)
    try:
        state_dict = torch.load(MODEL_WEIGHTS_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict, strict=True)
    except Exception as e:
        st.error(f"Error loading model weights for standard classifier: {e}")
        st.exception(e)
        st.stop()
    model.to(DEVICE)
    model.eval()
    return model


# --- Caching ---
@st.cache_resource
def cached_feature_extractor_model():
    """Build the feature extractor (backbone + projection) once and cache it."""
    base_model = get_feature_extractor_base()
    model = EfficientNetWithProjection(base_model, output_dim=1024)
    model.to(DEVICE)
    model.eval()
    st.sidebar.info("Feature extractor model ready (cached).")
    return model


@st.cache_resource
def cached_standard_classifier():
    """Load the standard classifier once and cache it."""
    model = load_standard_classifier()
    st.sidebar.info("Standard classifier model ready (cached).")
    return model


# --- Data Loading ---
# (Uses base_path from the Git repo and rare_path from persistent storage.)
@st.cache_data
def get_combined_dataset_and_indices(base_path, rare_path):
    """Load the base dataset plus any rare classes and index them by label.

    Args:
        base_path: ImageFolder root of the base dataset.
        rare_path: ImageFolder root of the rare classes; may be missing or
            empty, in which case only the base dataset is used.

    Returns:
        tuple: ``(combined_dataset, indices, class_names, num_base_classes)``
        where ``indices`` maps each label to the list of positions of its
        samples in ``combined_dataset`` and ``class_names`` lists base names
        followed by rare names. Rare labels are offset by
        ``num_base_classes``.

    Any loading error is reported via Streamlit and stops the script run.
    """
    try:
        transform_local = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Base dataset (Git repo path)
        if not os.path.isdir(base_path):
            st.error(f"Base dataset path not found: {base_path}")
            st.stop()
        full_dataset = datasets.ImageFolder(base_path, transform=transform_local)
        # ImageFolder already orders `classes` to match its label indices.
        base_class_names = list(full_dataset.classes)
        num_base_classes = len(base_class_names)

        rare_class_names = []
        combined_dataset = full_dataset  # Start with base dataset

        # Rare dataset (persistent storage path). The scandir iterator is
        # closed explicitly so no directory handle is leaked.
        rare_has_entries = False
        if os.path.isdir(rare_path):
            with os.scandir(rare_path) as entries:
                rare_has_entries = any(entries)

        if rare_has_entries:
            try:
                rare_dataset = datasets.ImageFolder(rare_path, transform=transform_local)
                if len(rare_dataset.samples) > 0:
                    # Shift rare labels so they follow the base labels, and
                    # keep the derived attributes consistent with `samples`.
                    shifted = [(path, label + num_base_classes) for path, label in rare_dataset.samples]
                    rare_dataset.samples = shifted
                    rare_dataset.imgs = shifted
                    rare_dataset.targets = [label for _, label in shifted]
                    combined_dataset = ConcatDataset([full_dataset, rare_dataset])
                    rare_class_names = list(rare_dataset.classes)
                else:
                    st.info(f"Rare dataset directory '{rare_path}' exists but is empty.")
            except Exception as e_rare:
                st.warning(f"Could not load rare dataset from {rare_path}: {e_rare}. Using base dataset only.")
        else:
            st.info(f"Rare dataset directory '{rare_path}' not found or empty. Using base dataset only.")

        # --- Index calculation: label -> positions in combined_dataset ---
        indices = {}
        if isinstance(combined_dataset, ConcatDataset):
            position = 0
            for ds in combined_dataset.datasets:
                if hasattr(ds, 'samples'):
                    labels = (label for _, label in ds.samples)
                else:
                    # Fallback for datasets without `samples`: read each item.
                    labels = (ds[i][1] for i in range(len(ds)))
                for label in labels:
                    indices.setdefault(label, []).append(position)
                    position += 1
        elif isinstance(combined_dataset, datasets.ImageFolder):
            for idx, (_, label) in enumerate(combined_dataset.samples):
                indices.setdefault(label, []).append(idx)
        else:
            st.error("Unexpected dataset type encountered when building indices.")
            st.stop()

        class_names = base_class_names + rare_class_names

        if len(class_names) == 0:
            st.error("No classes found in base or rare datasets. Check paths/contents.")
            st.stop()

        return combined_dataset, indices, class_names, num_base_classes

    except FileNotFoundError as e:
        st.error(f"Dataset path error: {e}. Check BASE_DATASET ('{base_path}') and RARE_DATASET ('{rare_path}').")
        st.stop()
    except Exception as e:
        st.error(f"Error loading datasets: {e}")
        st.exception(e)
        st.stop()


# --- Global transform ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# --- Few-Shot Learning Functions ---
def create_episode(dataset, class_indices, class_list, n_way=5, n_shot=5, n_query=5):
    """Sample one Prototypical-Network episode.

    Args:
        dataset: Indexable dataset returning ``(image_tensor, label)``.
        class_indices: Mapping label -> list of dataset positions.
        class_list: Unused here; kept for interface compatibility.
        n_way: Requested number of classes; reduced when fewer are eligible.
        n_shot: Support samples per class.
        n_query: Query samples per class.

    Returns:
        tuple: ``(support_imgs, support_labels, query_imgs, query_labels)``
        as tensors on ``DEVICE`` with episode-local labels ``0..n_way-1``,
        or four ``None`` values when an episode cannot be built.
    """
    failure = (None, None, None, None)
    samples_needed = n_shot + n_query

    n_way = min(n_way, len(class_indices))
    if n_way < 2:
        st.error(f"Episode creation failed: Need at least 2 classes with enough samples, found {n_way}.")
        return failure

    eligible_classes = [
        cls_id for cls_id, positions in class_indices.items()
        if len(positions) >= samples_needed
    ]
    n_way = min(n_way, len(eligible_classes))
    if n_way < 2:
        st.error(f"Episode creation failed: Need at least 2 eligible classes (with {n_shot+n_query} samples each). Found {n_way}.")
        return failure

    selected_class_ids = random.sample(eligible_classes, n_way)

    support_imgs, query_imgs = [], []
    support_labels, query_labels = [], []

    # Position in the sampled list becomes the episode-local label.
    for episode_label, original_label in enumerate(selected_class_ids):
        sampled_indices = random.sample(class_indices[original_label], samples_needed)
        try:
            support_imgs += [dataset[i][0] for i in sampled_indices[:n_shot]]
            query_imgs += [dataset[i][0] for i in sampled_indices[n_shot:]]
        except IndexError as e:
            st.error(f"IndexError during episode creation for class {original_label}. Sampled: {sampled_indices}. Dataset len: {len(dataset)}.")
            st.exception(e)
            return failure
        except Exception as e:
            st.error(f"Error retrieving data during episode creation: {e}")
            st.exception(e)
            return failure

        support_labels += [episode_label] * n_shot
        query_labels += [episode_label] * n_query

    try:
        s_imgs_tensor = torch.stack(support_imgs).to(DEVICE)
        s_labels_tensor = torch.tensor(support_labels, dtype=torch.long).to(DEVICE)
        q_imgs_tensor = torch.stack(query_imgs).to(DEVICE)
        q_labels_tensor = torch.tensor(query_labels, dtype=torch.long).to(DEVICE)
        return s_imgs_tensor, s_labels_tensor, q_imgs_tensor, q_labels_tensor
    except Exception as e:
        st.error(f"Error stacking tensors in create_episode: {e}")
        st.exception(e)
        return failure


def _zero_loss():
    """Return a zero loss tensor on DEVICE that still supports backward()."""
    return torch.tensor(0.0, requires_grad=True).to(DEVICE)


def proto_loss(support_embeddings, support_labels, query_embeddings, query_labels):
    """Compute the Prototypical-Network loss and accuracy for one episode.

    One prototype (mean support embedding) is built per distinct support
    label. Labels need not be the contiguous range ``0..N-1``: query labels
    are mapped to the position of their prototype before comparison, which
    gives the same result as before for contiguous labels.

    Args:
        support_embeddings: Tensor with one row per support sample.
        support_labels: 1-D integer tensor, one label per support row.
        query_embeddings: Tensor with one row per query sample.
        query_labels: 1-D integer tensor, one label per query row.

    Returns:
        tuple: ``(loss, accuracy)``. A zero loss and ``0.0`` accuracy are
        returned when inputs are empty, fewer than two classes are present,
        or no query label has a prototype.
    """
    if support_embeddings is None or support_embeddings.numel() == 0 or \
       query_embeddings is None or query_embeddings.numel() == 0:
        return _zero_loss(), 0.0

    # torch.unique returns the labels sorted ascending, which searchsorted
    # below relies on.
    unique_episode_labels = torch.unique(support_labels)
    if unique_episode_labels.numel() < 2:
        return _zero_loss(), 0.0

    prototypes = torch.stack([
        support_embeddings[support_labels == label].mean(dim=0)
        for label in unique_episode_labels
    ])

    # Ignore query samples whose class has no prototype.
    valid_query_mask = torch.isin(query_labels, unique_episode_labels)
    if not torch.any(valid_query_mask):
        return _zero_loss(), 0.0

    filtered_query_embeddings = query_embeddings[valid_query_mask]
    filtered_query_labels = query_labels[valid_query_mask]
    # Row index of each query's prototype in `prototypes`.
    targets = torch.searchsorted(unique_episode_labels, filtered_query_labels)

    distances = torch.cdist(filtered_query_embeddings, prototypes)
    predictions = torch.argmin(distances, dim=1)
    total_predictions = targets.size(0)
    correct_predictions = (predictions == targets).sum().item()
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0
    # Closer prototype => larger logit.
    loss = F.cross_entropy(-distances, targets)
    return loss, accuracy


# NOTE(review): every parameter is underscore-prefixed, which st.cache_data
# excludes from the cache key, so the first result may be returned for all
# later calls until the cache is cleared - confirm callers clear it before
# recomputing prototypes after (re)training.
@st.cache_data(show_spinner="Calculating final prototypes for all classes...")
def calculate_final_prototypes(_model, _dataset, _class_names, _strategy):
    """Compute one mean-embedding prototype per class over the whole dataset.

    Args:
        _model: Embedding model; put into eval mode here.
        _dataset: Dataset yielding ``(image_tensor, label)``.
        _class_names: Class names; labels outside its index range are skipped.
        _strategy: Expected to be ``'train_projection'`` (a warning is shown
            otherwise, and processing continues).

    Returns:
        tuple: ``(prototypes, labels)`` - a tensor on ``DEVICE`` with one row
        per class and the matching list of labels in ascending order - or
        ``(None, None)`` when no prototype could be computed.
    """
    if _strategy != 'train_projection':
        st.warning(f"calculate_final_prototypes called with unexpected strategy: '{_strategy}'. Proceeding as if 'train_projection'.")

    _model.eval()
    all_embeddings = {}
    # DEVICE is a torch.device; comparing it to the string 'cuda' is always
    # False, so inspect its type instead.
    loader = DataLoader(_dataset, batch_size=128, shuffle=False, num_workers=0,
                        pin_memory=DEVICE.type == 'cuda')

    with torch.no_grad():
        for imgs, labs in loader:
            imgs = imgs.to(DEVICE)
            try:
                emb_cpu = _model(imgs).cpu()
                for label, embedding in zip(labs.tolist(), emb_cpu):
                    all_embeddings.setdefault(label, []).append(embedding)
            except Exception as e:
                # Best effort: report the failing batch and keep going.
                st.error(f"Error during embedding calculation batch: {e}")
                continue

    if not all_embeddings:
        st.warning("No embeddings were generated. Cannot calculate prototypes.")
        return None, None

    final_prototypes = []
    prototype_labels = []
    for label in sorted(all_embeddings):
        if not (0 <= label < len(_class_names)):
            st.warning(f"Skipping label {label} during prototype calculation: Out of bounds for class names list (len={len(_class_names)}).")
            continue
        class_embeddings_list = all_embeddings[label]
        if class_embeddings_list:
            try:
                prototype = torch.stack(class_embeddings_list).mean(dim=0)
                final_prototypes.append(prototype)
                prototype_labels.append(label)
            except Exception as e:
                st.error(f"Error processing embeddings for class {label} ('{_class_names[label]}'): {e}")
                continue

    if not final_prototypes:
        st.warning("Could not calculate any valid final prototypes.")
        return None, None

    final_prototypes_tensor = torch.stack(final_prototypes).to(DEVICE)
    st.success(f"Calculated {len(final_prototypes)} final prototypes (Strategy: {_strategy}) for original labels: {prototype_labels}")
    return final_prototypes_tensor, prototype_labels


# --- Object Detection (YOLO) ---
@st.cache_resource(show_spinner="Loading detection model...")
def load_yolo_model():
    """Load the YOLO model from ``YOLO_MODEL_PATH`` once and cache it.

    A loading error is reported and stops the script run.
    """
    try:
        return YOLO(YOLO_MODEL_PATH)
    except Exception as e:
        st.error(f"Failed to load YOLO detection model from {YOLO_MODEL_PATH}: {e}")
        st.exception(e)
        st.stop()


def detect_objects(image):
    """Run YOLO detection on a PIL image.

    Args:
        image: PIL image; converted to RGB before inference.

    Returns:
        tuple: ``(annotated_rgb_array, detections_df)``. The DataFrame has the
        columns Class, Confidence, X_min, Y_min, X_max, Y_max. When inference
        fails, the unannotated RGB array and an empty DataFrame are returned.
    """
    model = load_yolo_model()
    img_array = np.array(image.convert("RGB"))
    # The model is fed a BGR array here, mirroring the BGR->RGB conversion
    # applied to its plotted output below.
    img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    try:
        results = model(img_bgr, device=DEVICE)
    except Exception as e:
        st.error(f"Error during YOLO inference: {e}")
        return img_array, pd.DataFrame()

    first_result = results[0]
    result_image_bgr = first_result.plot(conf=True, labels=True)
    result_image_rgb = cv2.cvtColor(result_image_bgr, cv2.COLOR_BGR2RGB)

    detections_list = []
    if first_result.boxes is not None:
        boxes = first_result.boxes.xyxy.cpu().numpy()
        confs = first_result.boxes.conf.cpu().numpy()
        cls_ids = first_result.boxes.cls.cpu().numpy().astype(int)
        class_names_map = model.names  # Use class names from YOLO model
        for box, conf, cls_id in zip(boxes, confs, cls_ids):
            cls_id = int(cls_id)
            detections_list.append({
                "Class": class_names_map.get(cls_id, f"ID {cls_id}"),
                "Confidence": conf,
                "X_min": box[0],
                "Y_min": box[1],
                "X_max": box[2],
                "Y_max": box[3],
            })

    detections_df = pd.DataFrame(detections_list)
    return result_image_rgb, detections_df


# === Main App Logic ===
st.title("🌿 Coffee Leaf Disease Classifier + Few-Shot Learning + Detection")

# --- Initialize Session State ---
st.session_state.setdefault('few_shot_trained', False)
st.session_state.setdefault('final_prototypes', None)
st.session_state.setdefault('prototype_labels', None)
st.session_state.setdefault('model_mode', 'standard')
st.session_state.setdefault('few_shot_strategy', None)

# --- Load Data ---
# Uses BASE_DATASET (Git) and RARE_DATASET (persistent storage)
combined_dataset, class_indices, class_names, num_base_classes = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET)

# Class counts are rendered here, once per run, rather than inside the cached
# loader, which is called from several places in the script.
st.sidebar.metric("Base Classes (Git)", num_base_classes)
st.sidebar.metric("Rare Classes (Storage)", len(class_names) - num_base_classes)
st.sidebar.metric("Total Classes", len(class_names))

# --- Sidebar ---
st.sidebar.header("⚙️ Options & Status")

# --- Mode Selection / Status ---
st.sidebar.subheader("Mode")
if st.sidebar.button("🔄 Reset to Standard Classifier"):
    st.session_state.model_mode = 'standard'
    st.session_state.few_shot_trained = False
    st.session_state.final_prototypes = None
    st.session_state.prototype_labels = None
st.session_state.few_shot_strategy = None st.success("Switched to Standard Classification Mode.") st.cache_data.clear() st.cache_resource.clear() st.rerun() mode_status = "Standard Classifier" strategy_info = "" # Check prototype_labels for length as well if st.session_state.model_mode == 'few_shot' and \ st.session_state.final_prototypes is not None and \ st.session_state.prototype_labels is not None and \ len(st.session_state.prototype_labels) > 0: mode_status = f"Few-Shot ({len(st.session_state.prototype_labels)} Prototypes)" # Use label length strategy_info = f"(Strategy: {st.session_state.get('few_shot_strategy', 'N/A').replace('_', ' ').title()})" st.sidebar.info(f"**Current Mode:** {mode_status} {strategy_info}") # --- Load/Delete Saved Few-Shot Models --- st.sidebar.divider() st.sidebar.subheader("💾 Saved Few-Shot Models (Persistent)") # Updated title saved_model_names = list_saved_models() # Reads from /data/... # --- Loading Section --- if not saved_model_names: st.sidebar.info("No saved few-shot models found in persistent storage.") # Updated msg else: selected_model_to_load = st.sidebar.selectbox( "Load a saved few-shot state:", options=[""] + saved_model_names, key="load_model_select", index=0 ) if st.sidebar.button("📥 Load Selected State", key="load_model_button", disabled=(not selected_model_to_load)): if selected_model_to_load: model_instance = cached_feature_extractor_model() _, _, current_cls_names_on_load, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) if load_few_shot_state(selected_model_to_load, model_instance, current_cls_names_on_load): # Reads from /data/... 
st.rerun() # --- Deleting Section --- if saved_model_names: st.sidebar.markdown("---") selected_model_to_delete = st.sidebar.selectbox( "Delete a saved few-shot state:", options=[""] + saved_model_names, key="delete_model_select", index=0 ) if selected_model_to_delete: confirm_delete = st.sidebar.checkbox(f"Confirm deletion of '{selected_model_to_delete}' from persistent storage", key="delete_confirm") # Updated msg if st.sidebar.button("❌ Delete Selected State", key="delete_model_button", disabled=(not confirm_delete)): if confirm_delete: if delete_saved_model(selected_model_to_delete): # Deletes from /data/... st.rerun() # --- Main Panel Options --- option = st.radio( "Choose an action:", ["Upload & Predict", "Add/Manage Rare Classes", "Train Few-Shot Model", "Detection"], horizontal=True, key="main_option" ) # Load models (cached) - Load from Git paths feature_extractor_model = cached_feature_extractor_model() standard_classifier_model = cached_standard_classifier() # --- Action Implementation --- # (Original Code) if option == "Upload & Predict": st.header("🔎 Upload Image for Prediction") uploaded_file = st.file_uploader("Choose a coffee leaf image...", type=["jpg", "jpeg", "png"], key="file_uploader") if uploaded_file: try: image = Image.open(uploaded_file).convert("RGB") st.image(image, caption="Uploaded Image", width=300) input_tensor = transform(image).unsqueeze(0).to(DEVICE) # Determine mode based on session state (which might be loaded from /data/...) 
use_few_shot = (st.session_state.model_mode == 'few_shot' and st.session_state.final_prototypes is not None and st.session_state.prototype_labels is not None and st.session_state.few_shot_strategy == 'train_projection' and st.session_state.final_prototypes.numel() > 0 ) if use_few_shot: st.subheader("Prediction using Prototypes") model_to_use = feature_extractor_model model_to_use.eval() strategy_for_pred = st.session_state.few_shot_strategy st.info(f"Using Few-Shot Strategy: {strategy_for_pred.replace('_', ' ').title()}") # Added info with torch.no_grad(): embedding = model_to_use(input_tensor) prototypes_for_pred = st.session_state.final_prototypes.to(DEVICE) if embedding.shape[1] != prototypes_for_pred.shape[1]: st.error(f"Dimension mismatch! Emb: {embedding.shape[1]}, Proto: {prototypes_for_pred.shape[1]}.") st.stop() distances = torch.cdist(embedding, prototypes_for_pred) pred_prototype_index = torch.argmin(distances, dim=1).item() # Use original label list from session state predicted_original_label = st.session_state.prototype_labels[pred_prototype_index] # Get current class names list (base+rare) _, _, current_class_names_pred, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) # Re-fetch current names if 0 <= predicted_original_label < len(current_class_names_pred): predicted_class_name = current_class_names_pred[predicted_original_label] confidence_scores = torch.softmax(-distances, dim=1) confidence = confidence_scores[0, pred_prototype_index].item() st.metric(label="Prediction (Prototype)", value=predicted_class_name, delta=f"{confidence * 100:.1f}% Confidence") st.info(f"(Matched prototype for class: '{predicted_class_name}' [Orig Label: {predicted_original_label}])") # Added info else: st.error(f"Predicted prototype label index {predicted_original_label} out of range for current classes ({len(current_class_names_pred)}).") else: st.subheader("Prediction using Standard Classifier") if st.session_state.model_mode != 'standard': 
st.warning("Falling back to Standard Classifier mode.") model_to_use = standard_classifier_model model_to_use.eval() with torch.no_grad(): outputs = model_to_use(input_tensor) probs = torch.softmax(outputs, dim=1) pred_label = torch.argmax(probs, dim=1).item() confidence = probs[0][pred_label].item() # Use current class names list (base+rare) and num_base_classes _, _, current_class_names_pred, num_base_classes_pred = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) # Re-fetch if 0 <= pred_label < num_base_classes_pred: # Compare against num_base_classes predicted_class_name = current_class_names_pred[pred_label] # Get name from full list st.metric(label="Prediction (Standard)", value=predicted_class_name, delta=f"{confidence * 100:.1f}% Confidence") else: st.error(f"Standard classifier predicted label {pred_label}, out of range for base classes ({num_base_classes_pred}).") except Exception as e: st.error(f"An error occurred during prediction: {e}") st.exception(e) elif option == "Detection": # (Original Code - Uses YOLO model from Git) st.header("🕵️ Object Detection with YOLO") uploaded_file_detect = st.file_uploader("Upload an image for detection", type=["jpg", "jpeg", "png"], key="detect_uploader") if uploaded_file_detect: try: image_detect = Image.open(uploaded_file_detect).convert("RGB") result_image, detections = detect_objects(image_detect) display_image = Image.fromarray(result_image) display_image = ImageOps.contain(display_image, (900, 700)) st.image(display_image, caption="Detection Result", use_container_width=True) if not detections.empty: st.subheader("📋 Detection Results:") detections['Confidence'] = detections['Confidence'].map('{:.1%}'.format) detections[['X_min', 'Y_min', 'X_max', 'Y_max']] = detections[['X_min', 'Y_min', 'X_max', 'Y_max']].round(1) st.dataframe(detections[['Class', 'Confidence', 'X_min', 'Y_min', 'X_max', 'Y_max']]) else: st.info("No objects detected.") except Exception as e: st.error(f"An error occurred during 
detection: {e}") st.exception(e) elif option == "Add/Manage Rare Classes": # (Original Code - interacts with RARE_DATASET which now points to /data/...) st.header("➕ Add New Rare Class (to Persistent Storage)") # Updated title n_shot_req = 2 n_query_req = 2 required_samples = n_shot_req + n_query_req st.write(f"Upload at least **{required_samples}** sample images. Images saved to `{RARE_DATASET}`.") # Show path with st.form("add_class_form"): new_class_name = st.text_input("Enter the name for the new rare class:") uploaded_files_rare = st.file_uploader( f"Upload {required_samples} or more images:", accept_multiple_files=True, type=["jpg", "jpeg", "png"], key="add_class_uploader" ) submitted_add = st.form_submit_button("Add Class") if submitted_add: valid = True if not new_class_name or not new_class_name.strip(): st.warning("Please enter a valid class name."); valid = False if len(uploaded_files_rare) < required_samples: st.warning(f"Please upload at least {required_samples} images."); valid = False if valid: sanitized_class_name = "".join(c for c in new_class_name if c.isalnum() or c in (' ', '_')).strip().replace(" ", "_") if not sanitized_class_name: st.error("Invalid class name after sanitization.") else: new_class_dir = os.path.join(RARE_DATASET, sanitized_class_name) # Path in /data/... if os.path.exists(new_class_dir): st.warning(f"Class directory '{sanitized_class_name}' already exists in {RARE_DATASET}.") # Updated msg else: try: os.makedirs(new_class_dir, exist_ok=True) # Create in /data/... image_save_errors = 0 for i, file in enumerate(uploaded_files_rare): try: img = Image.open(file).convert("RGB") base, ext = os.path.splitext(file.name) safe_base = "".join(c for c in base if c.isalnum() or c in ('_', '-')).strip()[:50] filename = f"{safe_base}_{random.randint(1000, 9999)}_{i+1}.jpg" save_path = os.path.join(new_class_dir, filename) # Save to /data/... 
img.save(save_path, format='JPEG', quality=95) except Exception as img_e: st.error(f"Error saving image {i+1} ({file.name}): {img_e}") image_save_errors += 1 if image_save_errors == 0: st.success(f"✅ Added class: '{sanitized_class_name}'. Re-run 'Train Few-Shot Model'.") # Simplified msg st.cache_data.clear() # Don't clear resource cache (models) unless needed # Reset state st.session_state.final_prototypes = None st.session_state.prototype_labels = None st.session_state.few_shot_strategy = None st.session_state.few_shot_trained = False st.session_state.model_mode = 'standard' st.rerun() else: st.error(f"Failed to save {image_save_errors} images.") except Exception as e: st.error(f"Error creating directory or saving images: {e}") st.exception(e) st.divider() st.header("❌ Delete a Rare Class (from Persistent Storage)") # Updated title try: if os.path.isdir(RARE_DATASET): # Check /data/Rare_data rare_class_dirs = [d for d in os.listdir(RARE_DATASET) if os.path.isdir(os.path.join(RARE_DATASET, d))] if not rare_class_dirs: st.info(f"No rare classes found to delete in {RARE_DATASET}.") # Updated msg else: with st.form("delete_class_form"): to_delete = st.selectbox("Select rare class to delete:", rare_class_dirs, key="delete_rare_select") confirm_delete_rare = st.checkbox(f"Confirm deletion of '{to_delete}' from persistent storage?", key="delete_rare_confirm") # Updated msg delete_submit_rare = st.form_submit_button("Delete Class") if delete_submit_rare: if confirm_delete_rare and to_delete: delete_path = os.path.join(RARE_DATASET, to_delete) # Path in /data/... try: shutil.rmtree(delete_path) # Deletes from /data/... 
st.success(f"✅ Deleted rare class: {to_delete}") st.cache_data.clear() # Reset state st.session_state.few_shot_trained = False st.session_state.final_prototypes = None st.session_state.prototype_labels = None st.session_state.few_shot_strategy = None st.session_state.model_mode = 'standard' st.rerun() except Exception as e: st.error(f"Error deleting directory {delete_path}: {e}") elif not confirm_delete_rare: st.warning("Please confirm the deletion.") else: st.info(f"Rare dataset directory '{RARE_DATASET}' does not exist.") # Updated msg except Exception as e: st.error(f"Error listing/deleting rare classes from {RARE_DATASET}: {e}") # Updated msg st.exception(e) elif option == "Train Few-Shot Model": # (Original Code - uses data potentially from /data/..., saves state to /data/...) st.header("🚀 Train Few-Shot Model") if len(class_names) < 2: st.error("Need at least two classes (Base + Rare combined) to perform few-shot training.") st.stop() # --- Training Parameters --- epochs = 10 n_way_train = len(class_names) # Original: Use all available classes episodes_per_epoch = 5 n_shot = 2 n_query = 2 learning_rate = 1e-4 weight_decay_proj = 1e-4 eligible_classes_check = [ cls_id for cls_id in class_indices if len(class_indices.get(cls_id, [])) >= (n_shot + n_query) ] if len(eligible_classes_check) < 2: st.error(f"Need >= 2 classes with {n_shot + n_query} samples. 
Found {len(eligible_classes_check)}.") st.stop() # Adjust n_way based on eligibility if using all classes was intended if n_way_train > len(eligible_classes_check): st.warning(f"Adjusting n-way from {n_way_train} to {len(eligible_classes_check)} based on eligible classes.") n_way_train = len(eligible_classes_check) # --- Training Form --- with st.form("train_form"): submitted_train = st.form_submit_button("Start Few-Shot Training") if submitted_train: active_strategy = 'train_projection' st.info(f"🚀 Starting few-shot process...") # Re-fetch/confirm data state current_combined_dataset_train, current_indices_train, current_names_train, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) current_eligible_train = [ cls_id for cls_id in current_indices_train if len(current_indices_train.get(cls_id, [])) >= (n_shot + n_query) ] if len(current_eligible_train) < 2: st.error(f"Error before starting: Need >= 2 eligible classes, found {len(current_eligible_train)}."); st.stop() n_way_for_episode = min(n_way_train, len(current_eligible_train)) # Final check on n_way if n_way_for_episode < 2: st.error("Error: Cannot create 2-way+ episodes."); st.stop() progress_bar = st.progress(0) status_placeholder = st.empty() chart_placeholder = st.empty() # Keep placeholder model_train = cached_feature_extractor_model() model_train.train() # --- Optimizer Setup (Original) --- try: for param in model_train.model.parameters(): param.requires_grad = False for param in model_train.projection.parameters(): param.requires_grad = True trainable_params = list(filter(lambda p: p.requires_grad, model_train.parameters())) if not trainable_params: st.error("No trainable params found!"); st.stop() except AttributeError as e: st.error(f"Model setup error: {e}"); st.stop() optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay_proj) # --- Training Loop (Original) --- loss_history = [] accuracy_history = [] total_steps = epochs * episodes_per_epoch 
current_step = 0 training_successful = False for epoch in range(epochs): epoch_loss = 0.0 epoch_accuracy = 0.0 valid_episodes_in_epoch = 0 for episode in range(episodes_per_epoch): current_step += 1 s_imgs, s_labels, q_imgs, q_labels = create_episode( current_combined_dataset_train, current_indices_train, current_names_train, n_way=n_way_for_episode, n_shot=n_shot, n_query=n_query ) if s_imgs is None: continue try: s_emb = model_train(s_imgs) q_emb = model_train(q_imgs) loss, accuracy = proto_loss(s_emb, s_labels, q_emb, q_labels) # Use original proto_loss except Exception as model_e: st.error(f"Model/Loss Error Ep {epoch+1}-{episode+1}: {model_e}"); continue if loss is not None and not torch.isnan(loss) and loss.requires_grad: try: optimizer.zero_grad(); loss.backward(); optimizer.step() epoch_loss += loss.item(); epoch_accuracy += accuracy; valid_episodes_in_epoch += 1 except Exception as optim_e: st.error(f"Optim Error Ep {epoch+1}-{episode+1}: {optim_e}") elif torch.isnan(loss): st.warning(f"NaN Loss Ep {epoch+1}-{episode+1}") if (episode + 1) % 5 == 0 or episode == episodes_per_epoch - 1: progress = current_step / total_steps progress_bar.progress(min(progress, 1.0)) # Log epoch results (Original logic) if valid_episodes_in_epoch > 0: avg_epoch_loss = epoch_loss / valid_episodes_in_epoch avg_epoch_accuracy = epoch_accuracy / valid_episodes_in_epoch loss_history.append(avg_epoch_loss) accuracy_history.append(avg_epoch_accuracy) status_placeholder.text(f"Epoch {epoch+1}/{epochs} | Loss: {avg_epoch_loss:.4f} | Acc: {avg_epoch_accuracy:.2%}") else: status_placeholder.text(f"Epoch {epoch+1}/{epochs} | No valid episodes.") loss_history.append(float('nan')); accuracy_history.append(float('nan')) status_placeholder.success("✅ Training Finished!") # --- Final Prototype Calculation (Original) --- st.info(f"Calculating final prototypes...") model_train.eval() st.cache_data.clear() final_combined_dataset_proto, _, final_class_names_proto, _ = 
get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) final_prototypes_tensor, final_prototype_labels = calculate_final_prototypes( model_train, final_combined_dataset_proto, final_class_names_proto, active_strategy ) # --- Store results in session state (Original) --- if final_prototypes_tensor is not None and final_prototype_labels is not None: st.session_state.final_prototypes = final_prototypes_tensor st.session_state.prototype_labels = final_prototype_labels st.session_state.few_shot_strategy = active_strategy st.session_state.few_shot_trained = True st.session_state.model_mode = 'few_shot' st.success(f"Prototypes Calculated.") training_successful = True else: st.session_state.final_prototypes = None; st.session_state.prototype_labels = None st.session_state.few_shot_strategy = None; st.session_state.few_shot_trained = False st.session_state.model_mode = 'standard'; st.error("Prototype calculation failed.") training_successful = False st.rerun() # Rerun to clear form # --- SAVING SECTION (Original logic - saves to SAVED_MODELS_DIR which is now /data/...) --- if st.session_state.get('final_prototypes') is not None and \ st.session_state.get('model_mode') == 'few_shot' and \ st.session_state.get('few_shot_strategy') == 'train_projection': st.divider() st.subheader("💾 Save Current Few-Shot State (to Persistent Storage)") # Updated title st.info(f"Saves current state to {SAVED_MODELS_DIR}") # Show path save_model_name = st.text_input("Enter name:", key="save_model_name_input_main") if st.button("Save State", key="save_state_button_main"): if save_model_name: model_to_save = cached_feature_extractor_model() model_to_save.eval() _, _, current_cls_names_for_saving, _ = get_combined_dataset_and_indices(BASE_DATASET, RARE_DATASET) # save_few_shot_state saves to /data/... 
save_successful = save_few_shot_state( save_model_name, model_to_save, st.session_state.final_prototypes, st.session_state.prototype_labels, current_cls_names_for_saving, st.session_state.few_shot_strategy ) if save_successful: st.rerun() # Refresh sidebar list else: st.warning("Please enter name.") # --- END OF `elif option == "Train Few-Shot Model":` ---