import csv
import os
from datetime import datetime

import torch
import torch.nn.functional as F


# Column order of the results CSV; also the key order of every result dict.
_CSV_FIELDNAMES = (
    "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
    "classification_result", "non_defect_prob", "defect_prob",
    "nominal_description", "defective_description",
    "max_nominal_similarity", "max_defective_similarity",
)


def _as_list(value):
    """Return *value* as a list: lists pass through, tuples are copied, anything else is wrapped."""
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    return [value]


def _encode_references(model, images, device):
    """Encode every reference image, stack the results and L2-normalise the last dimension."""
    features = torch.stack([model.encode_image(img.to(device)) for img in images])
    # Out-of-place division: never mutate a tensor the model may still hold.
    return features / features.norm(dim=-1, keepdim=True)


def _best_match(query_features, reference_features):
    """Return ``(similarity, index)`` of the reference most similar to the query.

    Ties keep the first reference (strict ``>`` comparison). Each product must
    reduce to a single element because ``.item()`` is called on it.
    """
    best_similarity = -float('inf')
    best_idx = -1
    for idx, reference in enumerate(reference_features):
        similarity = (query_features @ reference.T).item()
        if similarity > best_similarity:
            best_similarity = similarity
            best_idx = idx
    return best_similarity, best_idx


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
) -> list:
    """Classify test images as nominal or defective by similarity to reference images.

    Every image is encoded with ``model.encode_image`` and L2-normalised along
    its last dimension. For each test image the most similar nominal reference
    and the most similar defective reference are found; a softmax over those
    two raw similarities (no temperature scaling is applied here) gives the
    reported probabilities, and the larger one decides the label (a tie is
    reported as "Nominal"). One row per test image is appended to the CSV file
    ``file_path/file_name``; the header is written when the file is new or empty.

    Args:
        model: Object exposing ``encode_image(tensor)``.
        test_images: One image tensor or a list/tuple of them. Each is assumed
            to hold a single preprocessed image, so that its similarity to one
            reference is a single scalar.
        test_image_filenames: Path(s) matching ``test_images`` by position.
        nominal_images: Reference image(s) of non-defective parts (at least one).
        nominal_descriptions: Description(s) matching ``nominal_images`` by position.
        defective_images: Reference image(s) of defective parts (at least one).
        defective_descriptions: Description(s) matching ``defective_images`` by position.
        num_few_shot_nominal_imgs: Recorded verbatim in every result row.
        device: Device the image tensors are moved to before encoding.
        file_path: Directory of the results CSV; created if missing.
        file_name: Name of the results CSV inside ``file_path``.
        print_one_liner: If true, print a one-line summary per test image.

    Returns:
        A list with one dict per test image, keyed by the CSV column names.

    Raises:
        ValueError: If a reference set is empty, or if fewer filenames or
            descriptions are supplied than the images they belong to.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    # Fail fast with a clear message instead of an opaque torch.stack error or
    # an IndexError raised halfway through the run (before anything is saved).
    if not nominal_images:
        raise ValueError("nominal_images must contain at least one image")
    if not defective_images:
        raise ValueError("defective_images must contain at least one image")
    if len(test_image_filenames) < len(test_images):
        raise ValueError(
            f"got {len(test_images)} test images but only "
            f"{len(test_image_filenames)} test image filenames"
        )
    if len(nominal_descriptions) < len(nominal_images):
        raise ValueError(
            f"got {len(nominal_images)} nominal images but only "
            f"{len(nominal_descriptions)} nominal descriptions"
        )
    if len(defective_descriptions) < len(defective_images):
        raise ValueError(
            f"got {len(defective_images)} defective images but only "
            f"{len(defective_descriptions)} defective descriptions"
        )

    os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)
    results = []

    with torch.no_grad():
        nominal_features = _encode_references(model, nominal_images, device)
        defective_features = _encode_references(model, defective_images, device)

        for test_img, test_filename in zip(test_images, test_image_filenames):
            test_features = model.encode_image(test_img.to(device))
            test_features = test_features / test_features.norm(dim=-1, keepdim=True)

            max_nominal_similarity, max_nominal_idx = _best_match(
                test_features, nominal_features
            )
            max_defective_similarity, max_defective_idx = _best_match(
                test_features, defective_features
            )

            # Two-way softmax over the best nominal / best defective similarity.
            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            prob_not_defective, prob_defective = F.softmax(similarities, dim=0).tolist()

            classification = "Defective" if prob_defective > prob_not_defective else "Nominal"

            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": test_filename,
                # basename handles the platform's path separators.
                "image_name": os.path.basename(test_filename),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3),
            })

            if print_one_liner:
                print(f"{test_filename} → {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    # An existing but empty file still needs its header row.
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    # Append mode creates the file when it does not exist yet.
    with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=_CSV_FIELDNAMES)
        if needs_header:
            writer.writeheader()
        writer.writerows(results)

    return results