t0m-R
committed on
Commit
·
18cc612
1
Parent(s):
d07aa47
Update README.md
Browse files
README.md
CHANGED
|
@@ -37,17 +37,45 @@ This approach significantly improves the model's ability to identify the subtle
|
|
| 37 |
The following Python code shows how to load and use the model for inference.
|
| 38 |
|
| 39 |
```python
|
| 40 |
-
from transformers import AutoModelForImageClassification
|
| 41 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
# Load the model from the Hub
|
| 44 |
model_name = "t0m-R/vit-stm-artifact-fft"
|
| 45 |
model = AutoModelForImageClassification.from_pretrained(model_name)
|
| 46 |
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# preprocessed_image = your_custom_fft_preprocessing_function("path/to/your/stm_image")
|
| 51 |
|
| 52 |
# Run inference
|
| 53 |
with torch.no_grad():
|
|
@@ -61,12 +89,7 @@ print(f"Predicted Label: {predicted_label}")
|
|
| 61 |
|
| 62 |
## Preprocessing
|
| 63 |
|
| 64 |
-
**This model will not work with standard image preprocessing.** The input must be a 3-channel tensor representing the grayscale image, FFT amplitude, and FFT phase.
|
| 65 |
-
|
| 66 |
-
* Loading the image as grayscale and resizing it to 224x224.
|
| 67 |
-
* Applying a 2D Fast Fourier Transform (`numpy.fft.fft2`).
|
| 68 |
-
* Calculating the amplitude (`np.abs`) and phase (`np.angle`).
|
| 69 |
-
* Normalizing and stacking the three channels into a single tensor.
|
| 70 |
|
| 71 |
## Training Data
|
| 72 |
|
|
|
|
| 37 |
The following Python code shows how to load and use the model for inference.
|
| 38 |
|
| 39 |
```python
|
|
|
|
| 40 |
import torch
|
| 41 |
+
import numpy as np
|
| 42 |
+
from PIL import Image
|
| 43 |
+
from transformers import AutoModelForImageClassification
|
| 44 |
+
|
| 45 |
+
def preprocess_for_artifact_detection(image_path):
    """
    Load an STM image and convert it to the 3-channel tensor the model expects.

    The three channels are: the grayscale image, the log-scaled FFT amplitude,
    and the FFT phase, each min-max normalized into the [0, 1] range.

    Args:
        image_path: Path to the STM image file (any format PIL can open).

    Returns:
        A torch.float32 tensor of shape (1, 3, 224, 224), ready for the model.
    """
    def _minmax_01(channel):
        # Min-max scale to [0, 1]. A constant channel (zero span) maps to all
        # zeros instead of dividing by zero, which would fill it with NaNs —
        # e.g. a uniform input image produces an all-zero FFT phase channel.
        lo = np.min(channel)
        span = np.max(channel) - lo
        if span == 0:
            return np.zeros_like(channel)
        return (channel - lo) / span

    # 1. Load and prepare the grayscale channel (values scaled into [0, 1])
    with Image.open(image_path) as img:
        img = img.convert('L').resize((224, 224))
        grayscale_img = np.array(img) / 255.0

    # 2. Compute the 2D FFT; shift the zero-frequency component to the center
    fft_shifted = np.fft.fftshift(np.fft.fft2(grayscale_img))

    # log1p compresses the huge dynamic range of the amplitude spectrum
    amplitude = _minmax_01(np.log1p(np.abs(fft_shifted)))
    phase = _minmax_01(np.angle(fft_shifted))

    # 3. Stack channels into (C, H, W) order
    stacked_channels = np.stack([grayscale_img, amplitude, phase], axis=0)

    # 4. Add a batch dimension (B, C, H, W) and return as a float tensor
    return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)
|
| 71 |
|
| 72 |
# Load the model from the Hub
|
| 73 |
model_name = "t0m-R/vit-stm-artifact-fft"
|
| 74 |
model = AutoModelForImageClassification.from_pretrained(model_name)
|
| 75 |
|
| 76 |
+
# Preprocess your image
|
| 77 |
+
image_path = "path/to/your/stm_image" # Replace with your image path
|
| 78 |
+
preprocessed_image = preprocess_for_artifact_detection(image_path)
|
|
|
|
| 79 |
|
| 80 |
# Run inference
|
| 81 |
with torch.no_grad():
|
|
|
|
| 89 |
|
| 90 |
## Preprocessing
|
| 91 |
|
| 92 |
+
**This model will not work with standard image preprocessing.** The input must be a 3-channel tensor representing the grayscale image, FFT amplitude, and FFT phase, as implemented in the function provided in the "How to Use" section.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
## Training Data
|
| 95 |
|