ANISH-j committed
Commit dc4e3e8 · verified · 1 Parent(s): 5df6d26

Upload 2 files

Files changed (2):
  1. v2/README.md +95 -0
  2. v2/infer.py +84 -0
v2/README.md ADDED
@@ -0,0 +1,95 @@
+ # Minimal Inference Setup
+
+ This project provides a lightweight setup for running inference with a pre-trained model.
+ It contains the model configuration, trained weights, and a Python script to perform inference.
+
+ ---
+
+ ## Project Structure
+
+ ```
+ .
+ ├── model/
+ │   ├── config.json          # Model configuration file
+ │   └── model.safetensors    # Pre-trained model weights
+ └── infer.py                 # Script to run inference on input data
+ ```
+
+ ---
+
+ ## Prerequisites
+
+ - Python 3.8+
+ - PyTorch
+ - Transformers library
+ - safetensors
+ - PIL (Pillow)
+ - tkinter (bundled with most Python installations; required for the GUI in `infer.py`)
+
+ Install required packages:
+
+ ```bash
+ pip install torch transformers safetensors pillow
+ ```
+
+ ---
+
+ ## Files Description
+
+ ### model/config.json
+ Defines the architecture and hyperparameters of the model (e.g., hidden size, number of layers, vocabulary size).
+
+ Required to correctly instantiate the model before loading the weights.
+
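For orientation, a classification config usually carries the architecture type and the label mapping. The fragment below is purely illustrative (the field values and labels are made up, not taken from the actual file in `model/`):

```json
{
  "model_type": "siglip",
  "hidden_size": 768,
  "num_hidden_layers": 12,
  "id2label": { "0": "Grade A", "1": "Grade B" },
  "label2id": { "Grade A": 0, "Grade B": 1 }
}
```

`id2label` is what `infer.py` reads to turn a predicted class index back into a human-readable label.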
+ ### model/model.safetensors
+ Contains the trained weights of the model.
+
+ Stored in the Safetensors format for safety and efficiency.
+
+ ### infer.py
+ Main script to perform inference with the pre-trained model. It opens a small Tkinter GUI for selecting and classifying images.
+
+ **Responsibilities:**
+ - Loads the image processor and the fine-tuned SigLIP classification model
+ - Preprocesses the selected image (RGB conversion, resize, normalization)
+ - Runs the model forward pass
+ - Displays the predicted label and its confidence in the GUI
+
+ **Usage:**
+ ```bash
+ python infer.py
+ ```
+
+ Then click **Load Image** and choose a `.png`, `.jpg`, or `.jpeg` file; the prediction appears below the image.
+
+ ---
+
+ ## Usage Workflow
+
+ 1. Place the model files (`config.json` and `model.safetensors`) inside the `model/` directory.
+ 2. Run `python infer.py` and select an image with the **Load Image** button.
+ 3. The script will display the prediction/classification result along with its confidence.
+
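Under the hood, the image preprocessing that `infer.py` delegates to the Hugging Face processor follows the usual resize-and-normalize pattern. A minimal sketch with Pillow alone (the 224x224 size and the mean-0.5/std-0.5 scaling to [-1, 1] are assumptions based on common SigLIP settings, not read from the actual processor config):

```python
from PIL import Image

def preprocess(image, size=(224, 224)):
    """Resize an image to the model input size and scale pixels to [-1, 1]."""
    img = image.convert("RGB").resize(size)
    # Scale 0..255 pixel values to [-1, 1] (mean 0.5, std 0.5 per channel)
    return [[(c / 255.0 - 0.5) / 0.5 for c in px] for px in img.getdata()]

# Demo on a solid red image: every pixel becomes [1.0, -1.0, -1.0]
demo = Image.new("RGB", (640, 480), color=(255, 0, 0))
data = preprocess(demo)
print(len(data), data[0])
```

In the real script this work (plus batching into a tensor) is done by `AutoImageProcessor`, so the sketch is only for understanding what the processor produces.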
+
+ ---
+
+ ## Notes
+
+ - Ensure the model files are compatible (same checkpoint version).
+ - For image-based models, inputs must be resized to the expected dimensions (e.g., 224x224 RGB).
+ - For text-based models, ensure the tokenizer is compatible with the config (may require adding tokenizer files).
+ - `infer.py` loads its image processor from `./base model`; make sure the processor files exist at that path, or adjust the path in the script.
+ - GPU is recommended for faster inference, but CPU is supported.
+
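The low-confidence fallback in `infer.py` (predict "Not an Embryo" whenever the top softmax probability drops below 0.45) can be illustrated in plain Python. The logits and label map below are made up for the example:

```python
import math

def classify(logits, id2label, threshold=0.45):
    """Softmax over raw logits; fall back to 'Not an Embryo' when confidence is low."""
    shifted = [x - max(logits) for x in logits]  # subtract max for numerical stability
    exps = [math.exp(x) for x in shifted]
    probs = [e / sum(exps) for e in exps]
    max_prob = max(probs)
    label = id2label[probs.index(max_prob)]
    if max_prob < threshold:
        label = "Not an Embryo"
    return label, max_prob

labels = {0: "Grade A", 1: "Grade B", 2: "Grade C"}  # hypothetical label map
print(classify([3.2, 0.1, -1.0], labels))  # one logit dominates -> confident prediction
print(classify([0.1, 0.0, -0.1], labels))  # near-uniform logits -> below threshold
```

The 0.45 cutoff is a heuristic the script exposes as a constant; it is worth tuning on held-out data rather than treating it as fixed.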
+
+ ---
+
+ ## License
+
+ [Add license information here if applicable]
+
+ ---
+
+ ## Contributing
+
+ [Add contribution guidelines here if applicable]
v2/infer.py ADDED
@@ -0,0 +1,84 @@
+ import tkinter as tk
+ from tkinter import filedialog, messagebox
+ from PIL import Image, ImageTk
+ import torch
+ from transformers import AutoImageProcessor, SiglipForImageClassification
+
+ # Model and processor paths (adjust if needed; assumes final model saved in output_dir)
+ model_path = "./model"  # Path to the fine-tuned model
+
+ # Load processor and model
+ processor = AutoImageProcessor.from_pretrained("./base model")
+ model = SiglipForImageClassification.from_pretrained(model_path)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+ model.eval()
+
+ # Get label mappings from model config
+ id2label = model.config.id2label
+
+ # Tkinter GUI
+ class ImageClassifierApp:
+     def __init__(self, root):
+         self.root = root
+         self.root.title("SigLIP2 Gardner Grading Classifier")
+         self.root.geometry("600x600")
+
+         # Label for instructions
+         self.instruction_label = tk.Label(root, text="Select an image to classify")
+         self.instruction_label.pack(pady=10)
+
+         # Button to load image
+         self.load_button = tk.Button(root, text="Load Image", command=self.load_image)
+         self.load_button.pack(pady=10)
+
+         # Canvas to display image
+         self.image_canvas = tk.Canvas(root, width=400, height=400, bg="white")
+         self.image_canvas.pack(pady=10)
+
+         # Label to display prediction
+         self.prediction_label = tk.Label(root, text="", font=("Arial", 14))
+         self.prediction_label.pack(pady=10)
+
+     def load_image(self):
+         file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.png *.jpg *.jpeg")])
+         if file_path:
+             try:
+                 # Open and convert image to RGB
+                 img = Image.open(file_path).convert("RGB")
+                 img_resized = img.resize((400, 400))  # For display
+                 self.photo_img = ImageTk.PhotoImage(img_resized)
+                 self.image_canvas.create_image(200, 200, image=self.photo_img)
+
+                 # Preprocess with explicit settings
+                 inputs = processor(
+                     images=img,
+                     return_tensors="pt",
+                     do_resize=True,
+                     size={"height": 224, "width": 224},  # Adjust based on model's expected size (common for SigLIP)
+                     do_normalize=True,
+                 ).to(device)
+
+                 # Inference
+                 with torch.no_grad():
+                     outputs = model(**inputs)
+                 logits = outputs.logits
+                 probabilities = torch.softmax(logits, dim=-1)
+                 max_prob, predicted_id = probabilities.max(dim=-1)
+                 predicted_label = id2label[predicted_id.item()]
+
+                 # Heuristic: if confidence is low, classify as "Not an Embryo"
+                 confidence_threshold = 0.45  # Adjust this threshold as needed
+                 if max_prob.item() < confidence_threshold:
+                     predicted_label = "Not an Embryo"
+
+                 # Display prediction
+                 self.prediction_label.config(
+                     text=f"Predicted Grade: {predicted_label} (Confidence: {max_prob.item():.2f})"
+                 )
+             except Exception as e:
+                 messagebox.showerror("Error", f"Failed to process image: {str(e)}")
+
+ if __name__ == "__main__":
+     root = tk.Tk()
+     app = ImageClassifierApp(root)
+     root.mainloop()