t0m-R commited on
Commit
90df10a
·
1 Parent(s): f921783

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -17
README.md CHANGED
@@ -44,36 +44,35 @@ from transformers import AutoModelForImageClassification
44
 
45
  def preprocess_for_artifact_detection(image_path):
46
  """
47
- Loads an STM image and converts it to the required 3-channel format
48
- (grayscale, FFT amplitude, FFT phase) for the model.
49
  """
50
- # 1. Load and prepare grayscale channel
51
- with Image.open(image_path) as img:
52
- img = img.convert('L').resize((224, 224))
53
- grayscale_img = np.array(img) / 255.0
54
-
55
- # 2. Compute FFT, Amplitude, and Phase
 
 
 
56
  fft_data = np.fft.fft2(grayscale_img)
57
  fft_shifted = np.fft.fftshift(fft_data)
58
 
59
- amplitude = np.log1p(np.abs(fft_shifted))
60
  phase = np.angle(fft_shifted)
61
 
62
- # 3. Normalize channels to be in a 0-1 range
63
- amplitude = (amplitude - np.min(amplitude)) / (np.max(amplitude) - np.min(amplitude))
64
- phase = (phase - np.min(phase)) / (np.max(phase) - np.min(phase))
65
-
66
- # 4. Stack channels and convert to PyTorch tensor (C, H, W)
67
- stacked_channels = np.stack([grayscale_img, amplitude, phase], axis=0)
68
 
69
- # 5. Add a batch dimension (B, C, H, W) and return as float tensor
70
  return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)
71
 
72
  # Load the model from the Hub
73
  model_name = "t0m-R/vit-stm-artifact-fft"
74
  model = AutoModelForImageClassification.from_pretrained(model_name)
75
 
76
- # Preprocess your image
77
  image_path = "path/to/your/stm_image" # Replace with your image path
78
  preprocessed_image = preprocess_for_artifact_detection(image_path)
79
 
 
44
 
45
  def preprocess_for_artifact_detection(image_path):
46
  """
47
+ Loads an STM image and converts it to the required 3-channel format
48
+ (grayscale, magnitude spectrum, phase) for the model.
49
  """
50
+ try:
51
+ with Image.open(image_path) as img:
52
+ img = img.convert('L').resize((224, 224))
53
+ grayscale_img = np.array(img) / 255.0
54
+ except FileNotFoundError:
55
+ print(f"Error: The file at {image_path} was not found.")
56
+ return None
57
+
58
+ # Compute FFT, Magnitude, and Phase
59
  fft_data = np.fft.fft2(grayscale_img)
60
  fft_shifted = np.fft.fftshift(fft_data)
61
 
62
+ magnitude_spectrum = np.abs(fft_shifted)
63
  phase = np.angle(fft_shifted)
64
 
65
+ # Stack channels and convert to PyTorch tensor (C, H, W)
66
+ stacked_channels = np.stack([grayscale_img, magnitude_spectrum, phase], axis=0)
 
 
 
 
67
 
68
+ # Add a batch dimension (B, C, H, W) and return as float tensor
69
  return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0)
70
 
71
  # Load the model from the Hub
72
  model_name = "t0m-R/vit-stm-artifact-fft"
73
  model = AutoModelForImageClassification.from_pretrained(model_name)
74
 
75
+ # Preprocess
76
  image_path = "path/to/your/stm_image" # Replace with your image path
77
  preprocessed_image = preprocess_for_artifact_detection(image_path)
78