Meeteshn committed on
Commit
d223bdb
·
verified ·
1 Parent(s): 7d80d87

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +95 -23
README.md CHANGED
@@ -60,35 +60,100 @@ pip install torch torchvision transformers scikit-learn pillow joblib numpy hugg
60
  ### Single Image Inference
61
 
62
  ```python
 
63
  import json
64
  import joblib
65
  from pathlib import Path
66
  from PIL import Image
67
  import torch
68
  import numpy as np
69
- from huggingface_hub import hf_hub_download
70
  from transformers import AutoImageProcessor, ViTModel
 
71
 
72
- # Configuration
73
- REPO_ID = "Meeteshn/vit_fruit_ripeness_classifier"
74
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
75
-
76
- # Load model components
77
- processor = AutoImageProcessor.from_pretrained(REPO_ID, subfolder="vit_fruit_ripeness_updated/processor", use_fast=True)
78
- backbone = ViTModel.from_pretrained(REPO_ID, subfolder="vit_fruit_ripeness_updated/vit_backbone")
79
- backbone.to(DEVICE)
80
- backbone.eval()
81
-
82
- # Load sklearn artifacts
83
- scaler_path = hf_hub_download(REPO_ID, "scaler.joblib")
84
- clf_path = hf_hub_download(REPO_ID, "logistic_model.joblib")
85
- metadata_path = hf_hub_download(REPO_ID, "metadata.json")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  scaler = joblib.load(scaler_path)
88
  clf = joblib.load(clf_path)
89
  metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8"))
90
  classes = metadata["classes"]
91
 
 
92
  def predict(image_path: str):
93
  """Predict ripeness condition for a single image."""
94
  img = Image.open(image_path).convert("RGB")
@@ -103,20 +168,27 @@ def predict(image_path: str):
103
  feat = pooled.cpu().numpy()
104
 
105
  feat_scaled = scaler.transform(feat)
106
- probs = clf.predict_proba(feat_scaled)[0]
 
 
 
 
 
 
 
 
107
  idx = int(np.argmax(probs))
108
-
109
- return classes[idx], float(probs[idx]), {
110
- classes[i]: float(probs[i]) for i in range(len(classes))
111
- }
112
 
113
- # Example usage
114
  if __name__ == "__main__":
115
- label, prob, all_probs = predict("my_apple.jpg")
 
116
  print(f"Prediction: {label} ({prob*100:.2f}%)")
117
- print("\nTop 5 probabilities:")
118
- for cls, p in sorted(all_probs.items(), key=lambda x: -x[1])[:5]:
119
  print(f" {cls}: {p*100:.2f}%")
 
120
  ```
121
 
122
  ### Batch Prediction
 
60
  ### Single Image Inference
61
 
62
  ```python
63
+
64
  import json
65
  import joblib
66
  from pathlib import Path
67
  from PIL import Image
68
  import torch
69
  import numpy as np
70
+ from huggingface_hub import hf_hub_download, HfApi
71
  from transformers import AutoImageProcessor, ViTModel
72
+ import warnings
73
 
74
# ----------------- CONFIG -----------------
# Hub repository that hosts the processor, backbone and sklearn artifacts.
REPO_ID = "Meeteshn/vit_fruit_ripeness_classifier"

# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Folder inside the repo that is probed in addition to the repo root.
NESTED_FOLDER = "vit_fruit_ripeness_updated"

# Number of highest-probability classes printed by the example at the bottom.
TOP_K = 5
# ------------------------------------------
80
+
81
def hf_download_try(repo_id: str, filename: str, nested_folder: str = NESTED_FOLDER):
    """
    Download `filename` from a Hub repo, trying the repo root and then a nested folder.

    Args:
        repo_id: Repository id passed to `hf_hub_download`.
        filename: Name of the file to fetch.
        nested_folder: Folder tried second, as `nested_folder/filename`.
            When empty, only the repo root is tried.

    Returns:
        The local path returned by `hf_hub_download` for the first candidate
        that downloads successfully.

    Raises:
        RuntimeError: If every candidate fails. The last underlying error is
            chained as `__cause__` so its traceback is not lost.
    """
    candidates = [filename]
    if nested_folder:
        # Skip the nested candidate for an empty folder: it would become "/filename".
        candidates.append(f"{nested_folder}/{filename}")

    last_exc = None
    for candidate in candidates:
        try:
            print(f"Trying to download '{candidate}' from '{repo_id}'...")
            path = hf_hub_download(repo_id=repo_id, filename=candidate)
        except Exception as e:
            # Best-effort probe: any failure just moves on to the next candidate.
            print(f"Not found at '{candidate}': {e}")
            last_exc = e
        else:
            print("Downloaded:", path)
            return path

    raise RuntimeError(
        f"Could not download '{filename}' from repo '{repo_id}'. Last error: {last_exc}"
    ) from last_exc
98
+
99
def load_processor_and_backbone(repo_id: str, nested_folder: str = NESTED_FOLDER, device: str = DEVICE):
    """
    Load the image processor and ViT backbone, probing several subfolder layouts.

    The processor is looked for in `processor`, then `nested_folder/processor`,
    then the repo root. For each processor that loads, the backbone is looked
    for in the sibling `vit_backbone` folder, then in the root `vit_backbone`
    folder. If nothing in the repo loads, the function warns and falls back to
    `google/vit-base-patch16-224`.

    Args:
        repo_id: Repository id passed to `from_pretrained`.
        nested_folder: Extra folder probed for `processor` / `vit_backbone`.
            When empty, only root-level locations are probed.
        device: Device the backbone is moved to.

    Returns:
        A `(processor, backbone)` tuple; the backbone is on `device` and in
        eval mode.
    """
    proc_candidates = ["processor"]
    if nested_folder:
        # Skip for an empty folder: the candidate would become "/processor".
        proc_candidates.append(f"{nested_folder}/processor")
    proc_candidates.append("")  # no subfolder (root)

    last_exc = None
    for sub in proc_candidates:
        try:
            if sub == "":
                print(f"Trying AutoImageProcessor.from_pretrained('{repo_id}')")
                processor = AutoImageProcessor.from_pretrained(repo_id, use_fast=True)
            else:
                print(f"Trying AutoImageProcessor.from_pretrained('{repo_id}', subfolder='{sub}')")
                processor = AutoImageProcessor.from_pretrained(repo_id, subfolder=sub, use_fast=True)

            # Look for the backbone next to the processor. Only the final path
            # component is swapped, so a parent folder whose name happens to
            # contain "processor" is left intact.
            parent = sub.rpartition("/")[0]
            backbone_sub = f"{parent}/vit_backbone" if parent else "vit_backbone"
            try:
                print(f"Trying ViTModel.from_pretrained('{repo_id}', subfolder='{backbone_sub}')")
                backbone = ViTModel.from_pretrained(repo_id, subfolder=backbone_sub)
            except Exception as e_backbone:
                if backbone_sub == "vit_backbone":
                    # The root location is what just failed; retrying it is pointless.
                    raise
                # final fallback: try root vit_backbone
                print(f"Backbone attempt failed for sub='{backbone_sub}': {e_backbone}. Trying root 'vit_backbone'.")
                backbone = ViTModel.from_pretrained(repo_id, subfolder="vit_backbone")

            backbone.to(device)
            backbone.eval()
            print(f"Loaded processor/backbone from subfolder='{sub or 'root'}'")
            return processor, backbone
        except Exception as e:
            # Reached for backbone failures as well as processor failures;
            # either way the next candidate layout is tried.
            print(f"Processor load failed for sub='{sub}': {e}")
            last_exc = e

    # ultimate fallback: official ViT from hub
    # NOTE(review): features from this generic checkpoint may not match what the
    # downstream scaler/classifier were fitted on - confirm this fallback is wanted.
    warnings.warn(
        "Could not load processor/backbone from repo; falling back to official 'google/vit-base-patch16-224'."
        f" Last error: {last_exc}"
    )
    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224", use_fast=True)
    backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
    backbone.to(device)
    backbone.eval()
    return processor, backbone
142
+
143
# ----------------- Load assets (robust) -----------------
# Processor and backbone first; the helper probes the repo layout itself.
processor, backbone = load_processor_and_backbone(
    REPO_ID,
    nested_folder=NESTED_FOLDER,
    device=DEVICE,
)

# Fetch the three sklearn-side artifacts in order (root first, then nested folder).
scaler_path, clf_path, metadata_path = (
    hf_download_try(REPO_ID, artifact_name, nested_folder=NESTED_FOLDER)
    for artifact_name in ("scaler.joblib", "logistic_model.joblib", "metadata.json")
)

# Deserialize the fitted scaler and classifier.
scaler = joblib.load(scaler_path)
clf = joblib.load(clf_path)

# The metadata file supplies the class names, read from its "classes" key.
with open(metadata_path, encoding="utf-8") as metadata_file:
    metadata = json.load(metadata_file)
classes = metadata["classes"]
155
 
156
+ # ----------------- Prediction function -----------------
157
  def predict(image_path: str):
158
  """Predict ripeness condition for a single image."""
159
  img = Image.open(image_path).convert("RGB")
 
168
  feat = pooled.cpu().numpy()
169
 
170
  feat_scaled = scaler.transform(feat)
171
+ # get probabilities (works for sklearn logistic / classifiers with predict_proba)
172
+ if hasattr(clf, "predict_proba"):
173
+ probs = clf.predict_proba(feat_scaled)[0]
174
+ else:
175
+ # fallback for classifiers without predict_proba
176
+ dec = clf.decision_function(feat_scaled)[0]
177
+ exp = np.exp(dec - np.max(dec))
178
+ probs = exp / exp.sum()
179
+
180
  idx = int(np.argmax(probs))
181
+ return classes[idx], float(probs[idx]), {classes[i]: float(probs[i]) for i in range(len(classes))}
 
 
 
182
 
183
# ----------------- Example usage -----------------
if __name__ == "__main__":
    # Classify one image and report the winning class with its probability.
    sample_image = "my_apple.jpg"  # change as needed
    label, prob, all_probs = predict(sample_image)
    print(f"Prediction: {label} ({prob*100:.2f}%)")

    # List the TOP_K classes, most probable first.
    print("\nTop probabilities:")
    ranked = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)
    for cls, p in ranked[:TOP_K]:
        print(f" {cls}: {p*100:.2f}%")
191
+
192
  ```
193
 
194
  ### Batch Prediction