t0m-R
commited on
Commit
·
90df10a
1
Parent(s):
f921783
Update README.md
Browse files
README.md
CHANGED
|
@@ -44,36 +44,35 @@ from transformers import AutoModelForImageClassification
|
|
| 44 |
|
| 45 |
def preprocess_for_artifact_detection(image_path):
|
| 46 |
"""
|
| 47 |
-
Loads an STM image and converts it to the required 3-channel format
|
| 48 |
-
(grayscale,
|
| 49 |
"""
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
| 56 |
fft_data = np.fft.fft2(grayscale_img)
|
| 57 |
fft_shifted = np.fft.fftshift(fft_data)
|
| 58 |
|
| 59 |
-
|
| 60 |
phase = np.angle(fft_shifted)
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
phase = (phase - np.min(phase)) / (np.max(phase) - np.min(phase))
|
| 65 |
-
|
| 66 |
-
# 4. Stack channels and convert to PyTorch tensor (C, H, W)
|
| 67 |
-
stacked_channels = np.stack([grayscale_img, amplitude, phase], axis=0)
|
| 68 |
|
| 69 |
-
#
|
| 70 |
return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)
|
| 71 |
|
| 72 |
# Load the model from the Hub
|
| 73 |
model_name = "t0m-R/vit-stm-artifact-fft"
|
| 74 |
model = AutoModelForImageClassification.from_pretrained(model_name)
|
| 75 |
|
| 76 |
-
# Preprocess
|
| 77 |
image_path = "path/to/your/stm_image" # Replace with your image path
|
| 78 |
preprocessed_image = preprocess_for_artifact_detection(image_path)
|
| 79 |
|
|
|
|
| 44 |
|
| 45 |
def preprocess_for_artifact_detection(image_path):
|
| 46 |
"""
|
| 47 |
+
Loads an STM image and converts it to the required 3-channel format
|
| 48 |
+
(grayscale, magnitude spectrum, phase) for the model.
|
| 49 |
"""
|
| 50 |
+
try:
|
| 51 |
+
with Image.open(image_path) as img:
|
| 52 |
+
img = img.convert('L').resize((224, 224))
|
| 53 |
+
grayscale_img = np.array(img) / 255.0
|
| 54 |
+
except FileNotFoundError:
|
| 55 |
+
print(f"Error: The file at {image_path} was not found.")
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
# Compute FFT, Magnitude, and Phase
|
| 59 |
fft_data = np.fft.fft2(grayscale_img)
|
| 60 |
fft_shifted = np.fft.fftshift(fft_data)
|
| 61 |
|
| 62 |
+
magnitude_spectrum = np.abs(fft_shifted)
|
| 63 |
phase = np.angle(fft_shifted)
|
| 64 |
|
| 65 |
+
# Stack channels and convert to PyTorch tensor (C, H, W)
|
| 66 |
+
stacked_channels = np.stack([grayscale_img, magnitude_spectrum, phase], axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
+
# Add a batch dimension (B, C, H, W) and return as float tensor
|
| 69 |
return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)
|
| 70 |
|
| 71 |
# Load the model from the Hub
|
| 72 |
model_name = "t0m-R/vit-stm-artifact-fft"
|
| 73 |
model = AutoModelForImageClassification.from_pretrained(model_name)
|
| 74 |
|
| 75 |
+
# Preprocess
|
| 76 |
image_path = "path/to/your/stm_image" # Replace with your image path
|
| 77 |
preprocessed_image = preprocess_for_artifact_detection(image_path)
|
| 78 |
|