---
license: apache-2.0
language: en
tags:
- image-classification
- vision-transformer
- pytorch
- stm
- materials-science
- nffa-di
base_model:
- google/vit-base-patch32-224-in21k
pipeline_tag: image-classification
---

# Vision Transformer for STM Multi-Tip Artifact Detection

This is a fine-tuned **Vision Transformer (ViT-B/32)** model for classifying Scanning Tunneling Microscopy (STM) images. It is designed to detect the presence of **multi-tip artifacts**, a common distortion that results in duplicated signals and complicates data interpretation.

This model was developed as part of the **NFFA-DI (Nano Foundries and Fine Analysis Digital Infrastructure)** project, funded by the European Union's NextGenerationEU program.

## Model Description

The model is a `ViT-B/32` pre-trained on ImageNet-21k. It was fine-tuned to classify an STM image as either `Artifact-Free` or `Multi-Tip Artifact`.

A key feature of this model is its use of a **Fast Fourier Transform (FFT)** based preprocessing method. The model's input is not a standard image but a 3-channel tensor composed of:

1. The grayscale STM image.
2. The **amplitude** of the image's Fourier transform.
3. The **phase** of the image's Fourier transform.

This approach significantly improves the model's ability to identify the subtle patterns characteristic of multi-tip artifacts.

## How to Use

The following Python code shows how to load and use the model for inference.

```python
import torch
import numpy as np
from PIL import Image
from transformers import AutoModelForImageClassification

def preprocess_for_artifact_detection(image_path):
    """
    Loads an STM image and converts it to the required 3-channel format
    (grayscale, magnitude spectrum, phase) for the model.
""" try: with Image.open(image_path) as img: img = img.convert('L').resize((224, 224)) grayscale_img = np.array(img) / 255.0 except FileNotFoundError: print(f"Error: The file at {image_path} was not found.") return None # Compute FFT, Magnitude, and Phase fft_data = np.fft.fft2(grayscale_img) fft_shifted = np.fft.fftshift(fft_data) magnitude_spectrum = np.abs(fft_shifted) phase = np.angle(fft_shifted) # Stack channels and convert to PyTorch tensor (C, H, W) stacked_channels = np.stack([grayscale_img, magnitude_spectrum, phase], axis=0) # Add a batch dimension (B, C, H, W) and return as float tensor return torch.tensor(stacked_channels, dtype=torch.float32).unsqueeze(0) # Load the model from the Hub model_name = "t0m-R/vit-stm-artifact-fft" model = AutoModelForImageClassification.from_pretrained(model_name) # Preprocess image_path = "path/to/your/stm_image" # Replace with your image path preprocessed_image = preprocess_for_artifact_detection(image_path) # Run inference with torch.no_grad(): logits = model(preprocessed_image).logits predicted_label_id = logits.argmax(-1).item() predicted_label = model.config.id2label[predicted_label_id] print(f"Predicted Label: {predicted_label}") # Expected output: "Predicted Label: Multi-Tip Artifact" ``` ## Preprocessing **This model will not work with standard image preprocessing.** The input must be a 3-channel tensor representing the grayscale image, FFT amplitude, and FFT phase, as implemented in the function provided in the "How to Use" section. ## Training Data The model was fine-tuned on a synthetic dataset generated from experimental STM images recorded at CNR-IOM, Trieste. Artifact-free images were transformed into synthetic multi-tip images by summing the clean image with translated and intensity-scaled versions of itself. 
## Citation

If you use this model in your research, please cite the original work:

```bibtex
@article{rodani2024enhancing,
  title={Enhancing Multi-Tip Artifact Detection in STM Images Using Fourier Transform and Vision Transformers},
  author={Rodani, Tommaso and Ansuini, Alessio and Cazziga, Alberto},
  journal={Accepted at the 1st Machine Learning for Life and Material Sciences Workshop at ICML},
  year={2024}
}
```