t0m-R committed
Commit 18cc612 · 1 Parent(s): d07aa47

Update README.md

Files changed (1):
  1. README.md +34 -11

README.md CHANGED
@@ -37,17 +37,45 @@ This approach significantly improves the model's ability to identify the subtle
 The following Python code shows how to load and use the model for inference.
 
 ```python
-from transformers import AutoModelForImageClassification
 import torch
+import numpy as np
+from PIL import Image
+from transformers import AutoModelForImageClassification
+
+def preprocess_for_artifact_detection(image_path):
+    """
+    Loads an STM image and converts it to the required 3-channel format
+    (grayscale, FFT amplitude, FFT phase) for the model.
+    """
+    # 1. Load and prepare grayscale channel
+    with Image.open(image_path) as img:
+        img = img.convert('L').resize((224, 224))
+        grayscale_img = np.array(img) / 255.0
+
+    # 2. Compute FFT, Amplitude, and Phase
+    fft_data = np.fft.fft2(grayscale_img)
+    fft_shifted = np.fft.fftshift(fft_data)
+
+    amplitude = np.log1p(np.abs(fft_shifted))
+    phase = np.angle(fft_shifted)
+
+    # 3. Normalize channels to be in a 0-1 range
+    amplitude = (amplitude - np.min(amplitude)) / (np.max(amplitude) - np.min(amplitude))
+    phase = (phase - np.min(phase)) / (np.max(phase) - np.min(phase))
+
+    # 4. Stack channels and convert to PyTorch tensor (C, H, W)
+    stacked_channels = np.stack([grayscale_img, amplitude, phase], axis=0)
+
+    # 5. Add a batch dimension (B, C, H, W) and return as float tensor
+    return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)
 
 # Load the model from the Hub
 model_name = "t0m-R/vit-stm-artifact-fft"
 model = AutoModelForImageClassification.from_pretrained(model_name)
 
-# NOTE: This model requires a custom FFT-based preprocessing function.
-# The 'preprocessed_image' tensor must have a shape of (1, 3, 224, 224).
-# See the "Preprocessing" section for details.
-# preprocessed_image = your_custom_fft_preprocessing_function("path/to/your/stm_image")
+# Preprocess your image
+image_path = "path/to/your/stm_image"  # Replace with your image path
+preprocessed_image = preprocess_for_artifact_detection(image_path)
 
 # Run inference
 with torch.no_grad():
@@ -61,12 +89,7 @@ print(f"Predicted Label: {predicted_label}")
 
 ## Preprocessing
 
-**This model will not work with standard image preprocessing.** The input must be a 3-channel tensor representing the grayscale image, FFT amplitude, and FFT phase. Please refer to the original paper for the exact implementation details. The core steps involve:
-
-* Loading the image as grayscale and resizing it to 224x224.
-* Applying a 2D Fast Fourier Transform (`numpy.fft.fft2`).
-* Calculating the amplitude (`np.abs`) and phase (`np.angle`).
-* Normalizing and stacking the three channels into a single tensor.
+**This model will not work with standard image preprocessing.** The input must be a 3-channel tensor representing the grayscale image, FFT amplitude, and FFT phase, as implemented in the function provided in the "How to Use" section.
 
 ## Training Data
 
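The FFT channel construction added in this commit can be sanity-checked without the model or an image file. The sketch below re-runs steps 2–4 of the new preprocessing function on a synthetic 224×224 array, using NumPy only (it skips the PIL load and the final torch conversion); the helper name `fft_channels` and the random input are illustrative, not part of the repository:

```python
import numpy as np

def fft_channels(grayscale_img):
    """Recreate steps 2-4 of the committed preprocessing on a (224, 224)
    grayscale array in [0, 1]: log-amplitude and phase of the shifted 2D FFT,
    each min-max normalized, stacked with the grayscale channel."""
    fft_shifted = np.fft.fftshift(np.fft.fft2(grayscale_img))
    amplitude = np.log1p(np.abs(fft_shifted))
    phase = np.angle(fft_shifted)
    amplitude = (amplitude - amplitude.min()) / (amplitude.max() - amplitude.min())
    phase = (phase - phase.min()) / (phase.max() - phase.min())
    return np.stack([grayscale_img, amplitude, phase], axis=0)

# Synthetic "STM image": uniform random noise in [0, 1)
rng = np.random.default_rng(0)
img = rng.random((224, 224))

channels = fft_channels(img)
print(channels.shape)  # (3, 224, 224)
```

Adding the batch dimension via `torch.tensor(channels, dtype=torch.float32).unsqueeze(0)` then yields the `(1, 3, 224, 224)` shape the model expects, matching the NOTE that this commit replaces with a real implementation.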