Upload 2 files
Browse files- v2/README.md +95 -0
- v2/infer.py +84 -0
v2/README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal Inference Setup
|
| 2 |
+
|
| 3 |
+
This project provides a lightweight setup for running inference with a pre-trained model.
|
| 4 |
+
It contains the model configuration, trained weights, and a Python script to perform inference.
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Project Structure
|
| 9 |
+
|
| 10 |
+
```
.
├── model/
│   ├── config.json          # Model configuration file
│   └── model.safetensors    # Pre-trained model weights
└── infer.py                 # Script to run inference on input data
```
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Prerequisites
|
| 21 |
+
|
| 22 |
+
- Python 3.8+
|
| 23 |
+
- PyTorch
|
| 24 |
+
- Transformers library
|
| 25 |
+
- safetensors
|
| 26 |
+
- PIL (Pillow)
|
| 27 |
+
- tkinter (required — `infer.py` provides a GUI file picker; bundled with most Python installations)
|
| 28 |
+
|
| 29 |
+
Install required packages:
|
| 30 |
+
|
| 31 |
+
```bash
|
| 32 |
+
pip install torch transformers safetensors pillow
|
| 33 |
+
```
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Files Description
|
| 38 |
+
|
| 39 |
+
### model/config.json
|
| 40 |
+
Defines the architecture and hyperparameters of the model (e.g., hidden size, number of layers, vocabulary size).
|
| 41 |
+
|
| 42 |
+
Required to correctly instantiate the model before loading the weights.
|
| 43 |
+
|
| 44 |
+
### model/model.safetensors
|
| 45 |
+
Contains the trained weights of the model.
|
| 46 |
+
|
| 47 |
+
Stored in the Safetensors format for safety and efficiency.
|
| 48 |
+
|
| 49 |
+
### infer.py
|
| 50 |
+
Main script to perform inference with the pre-trained model.
|
| 51 |
+
|
| 52 |
+
**Responsibilities:**
|
| 53 |
+
- Loads config.json and model.safetensors
|
| 54 |
+
- Preprocesses the input image (resize and normalization)
|
| 55 |
+
- Runs the model forward pass
|
| 56 |
+
- Outputs predictions
|
| 57 |
+
|
| 58 |
+
**Usage:**
```bash
python infer.py
```

A window opens; click **Load Image** and choose a `.png`, `.jpg`, or `.jpeg` file.
The predicted grade and its confidence score are shown below the image.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## Usage Workflow
|
| 71 |
+
|
| 72 |
+
1. Place the model files (`config.json` and `model.safetensors`) inside the `model/` directory.
2. Run `python infer.py` and select an image via the **Load Image** button.
3. The GUI displays the predicted grade and its confidence score.
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
## Notes
|
| 79 |
+
|
| 80 |
+
- Ensure the model files are compatible (same checkpoint version).
|
| 81 |
+
- For image-based models, inputs must be resized to the expected dimensions (e.g., 224x224 RGB).
|
| 82 |
+
- For text-based models, ensure the tokenizer is compatible with the config (may require adding tokenizer files).
|
| 83 |
+
- GPU is recommended for faster inference, but CPU is supported.
|
| 84 |
+
|
| 85 |
+
---
|
| 86 |
+
|
| 87 |
+
## License
|
| 88 |
+
|
| 89 |
+
[Add license information here if applicable]
|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
## Contributing
|
| 94 |
+
|
| 95 |
+
[Add contribution guidelines here if applicable]
|
v2/infer.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tkinter as tk
|
| 2 |
+
from tkinter import filedialog, messagebox
|
| 3 |
+
from PIL import Image, ImageTk
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoImageProcessor, SiglipForImageClassification
|
| 6 |
+
|
| 7 |
+
# Model and processor paths (adjust if needed; assumes final model saved in output_dir)
model_path = "./model"  # Path to the fine-tuned model

# Load processor and model.
# NOTE(review): the image processor is loaded from "./base model" (a path
# containing a space) while the weights come from model_path — presumably the
# processor config was only saved alongside the base checkpoint; confirm this
# path exists and matches the fine-tuned model's preprocessing.
processor = AutoImageProcessor.from_pretrained("./base model")
model = SiglipForImageClassification.from_pretrained(model_path)
# Prefer GPU when available; fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference mode: disables dropout / training-only behavior

# Get label mappings (class index -> human-readable label) from model config
id2label = model.config.id2label
|
| 19 |
+
|
| 20 |
+
# Tkinter GUI
|
| 21 |
+
class ImageClassifierApp:
|
| 22 |
+
def __init__(self, root):
|
| 23 |
+
self.root = root
|
| 24 |
+
self.root.title("SigLIP2 Gardner Grading Classifier")
|
| 25 |
+
self.root.geometry("600x600")
|
| 26 |
+
|
| 27 |
+
# Label for instructions
|
| 28 |
+
self.instruction_label = tk.Label(root, text="Select an image to classify")
|
| 29 |
+
self.instruction_label.pack(pady=10)
|
| 30 |
+
|
| 31 |
+
# Button to load image
|
| 32 |
+
self.load_button = tk.Button(root, text="Load Image", command=self.load_image)
|
| 33 |
+
self.load_button.pack(pady=10)
|
| 34 |
+
|
| 35 |
+
# Canvas to display image
|
| 36 |
+
self.image_canvas = tk.Canvas(root, width=400, height=400, bg="white")
|
| 37 |
+
self.image_canvas.pack(pady=10)
|
| 38 |
+
|
| 39 |
+
# Label to display prediction
|
| 40 |
+
self.prediction_label = tk.Label(root, text="", font=("Arial", 14))
|
| 41 |
+
self.prediction_label.pack(pady=10)
|
| 42 |
+
|
| 43 |
+
def load_image(self):
|
| 44 |
+
file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.png *.jpg *.jpeg")])
|
| 45 |
+
if file_path:
|
| 46 |
+
try:
|
| 47 |
+
# Open and convert image to RGB
|
| 48 |
+
img = Image.open(file_path).convert("RGB")
|
| 49 |
+
img_resized = img.resize((400, 400)) # For display
|
| 50 |
+
self.photo_img = ImageTk.PhotoImage(img_resized)
|
| 51 |
+
self.image_canvas.create_image(200, 200, image=self.photo_img)
|
| 52 |
+
|
| 53 |
+
# Preprocess with explicit settings
|
| 54 |
+
inputs = processor(
|
| 55 |
+
images=img,
|
| 56 |
+
return_tensors="pt",
|
| 57 |
+
do_resize=True,
|
| 58 |
+
size={"height": 224, "width": 224}, # Adjust based on model's expected size (common for SigLIP)
|
| 59 |
+
do_normalize=True
|
| 60 |
+
).to(device)
|
| 61 |
+
|
| 62 |
+
# Inference
|
| 63 |
+
with torch.no_grad():
|
| 64 |
+
outputs = model(**inputs)
|
| 65 |
+
logits = outputs.logits
|
| 66 |
+
probabilities = torch.softmax(logits, dim=-1)
|
| 67 |
+
max_prob, predicted_id = probabilities.max(dim=-1)
|
| 68 |
+
predicted_label = id2label[predicted_id.item()]
|
| 69 |
+
|
| 70 |
+
# Heuristic: If confidence is low, classify as "Not an Embryo"
|
| 71 |
+
confidence_threshold = 0.45# Adjust this threshold as needed
|
| 72 |
+
if max_prob.item() < confidence_threshold:
|
| 73 |
+
predicted_label = "Not an Embryo"
|
| 74 |
+
|
| 75 |
+
# Display prediction
|
| 76 |
+
self.prediction_label.config(text=f"Predicted Grade: {predicted_label} (Confidence: {max_prob.item():.2f})")
|
| 77 |
+
|
| 78 |
+
except Exception as e:
|
| 79 |
+
messagebox.showerror("Error", f"Failed to process image: {str(e)}")
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
    # Script entry point: build the GUI and enter the Tkinter event loop.
    # mainloop() blocks until the window is closed.
    root = tk.Tk()
    app = ImageClassifierApp(root)
    root.mainloop()
|