Patrick Daniel commited on
Commit
f3969d0
·
1 Parent(s): 7cea9f7

Fixed inference

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. .cache/.DS_Store +0 -0
  3. app.py +46 -17
  4. label_names.json +91 -0
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
.cache/.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -5,15 +5,39 @@ from PIL import Image
5
  import requests
6
  from io import BytesIO
7
  import os
 
 
 
 
8
 
9
- # Authenticate with Hugging Face Hub for private model access
10
- from huggingface_hub import login
11
- login(token=os.environ.get("HF_TOKEN")) # Set this in your Space's Secrets tab
12
 
13
- # Load model and processor
14
- model = ViTForImageClassification.from_pretrained("patcdaniel/phytoViT_508k_20250611")
15
- processor = ViTImageProcessor.from_pretrained("patcdaniel/phytoViT_508k_20250611")
16
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Use GPU if available
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -22,19 +46,27 @@ model.to(device)
22
  # Inference function
23
  def predict(image):
24
  try:
25
- image = image.convert("RGB")
26
- inputs = processor(images=image, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
27
 
28
  with torch.no_grad():
29
- logits = model(**inputs).logits
30
  probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
31
 
32
- topk = torch.topk(probs, k=2)
33
  top_indices = topk.indices.tolist()
34
  top_scores = topk.values.tolist()
35
 
36
- id2label = model.config.id2label
37
- top_labels = [id2label[str(i)] for i in top_indices]
38
 
39
  return {label: round(score, 4) for label, score in zip(top_labels, top_scores)}
40
 
@@ -62,10 +94,7 @@ with gr.Blocks() as demo:
62
  image_input = gr.Image(type="pil", label="Upload Image")
63
  url_input = gr.Textbox(label="...or paste image URL")
64
  predict_btn = gr.Button("Classify")
65
-
66
- with gr.Column():
67
- image_output = gr.Image(label="Input Image")
68
- label_output = gr.Label(label="Top 2 Predictions")
69
 
70
  predict_btn.click(fn=predict, inputs=image_input, outputs=label_output)
71
  url_input.change(fn=classify_from_url, inputs=url_input, outputs=label_output)
 
5
  import requests
6
  from io import BytesIO
7
  import os
8
+ from safetensors.torch import load_file
9
+ from huggingface_hub import hf_hub_download
10
+ from torchvision import transforms
11
+ import json
12
 
 
 
 
13
 
14
+ # Download the file from your model repo (replace with your actual token if private)
15
+ model_path = hf_hub_download(
16
+ repo_id="patcdaniel/phytoViT_508k_20250611",
17
+ filename="model.safetensors",
18
+ token=os.environ.get("HF_TOKEN") # omit this line if public
19
+ )
20
+ state_dict = load_file(model_path)
21
+
22
+ model = ViTForImageClassification.from_pretrained(
23
+ "google/vit-base-patch16-224-in21k",
24
+ num_labels=95 # this must match your training
25
+ )
26
+ model.load_state_dict(state_dict)
27
+
28
+ model_path = hf_hub_download(
29
+ repo_id="patcdaniel/phytoViT_508k_20250611",
30
+ filename="label_names.json",
31
+ token=os.environ.get("HF_TOKEN"),
32
+ local_dir="."
33
+ )
34
+
35
+ # Load class label dictionary (label -> index)
36
+ with open(model_path, "r") as f:
37
+ id2label = {int(k): v for k, v in json.load(f).items()}
38
+
39
+ # Convert to id -> label
40
+ # id2label = {v: k for k, v in label2id.items()}
41
 
42
  # Use GPU if available
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
46
  # Inference function
47
  def predict(image):
48
  try:
49
+ transform = transforms.Compose([
50
+ transforms.Resize(256),
51
+ transforms.CenterCrop(224),
52
+ transforms.ToTensor(), # Converts to [0, 1] range
53
+ transforms.Normalize(
54
+ mean=[0.485, 0.456, 0.406],
55
+ std=[0.229, 0.224, 0.225]
56
+ )
57
+ ])
58
+
59
+ pixel_values = transform(image).unsqueeze(0).to(device)
60
 
61
  with torch.no_grad():
62
+ logits = model(pixel_values).logits
63
  probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
64
 
65
+ topk = torch.topk(probs, k=5)
66
  top_indices = topk.indices.tolist()
67
  top_scores = topk.values.tolist()
68
 
69
+ top_labels = [id2label[i] for i in top_indices]
 
70
 
71
  return {label: round(score, 4) for label, score in zip(top_labels, top_scores)}
72
 
 
94
  image_input = gr.Image(type="pil", label="Upload Image")
95
  url_input = gr.Textbox(label="...or paste image URL")
96
  predict_btn = gr.Button("Classify")
97
+ label_output = gr.Label(label="Top 5 Predictions")
 
 
 
98
 
99
  predict_btn.click(fn=predict, inputs=image_input, outputs=label_output)
100
  url_input.change(fn=classify_from_url, inputs=url_input, outputs=label_output)
label_names.json ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "Akashiwo",
3
+ "1": "Alexandrium",
4
+ "2": "Amylax_Gonyaulax_Protoceratium",
5
+ "3": "Asterionellopsis",
6
+ "4": "Asterionellopsis_chain",
7
+ "5": "Asteromphalus",
8
+ "6": "Bad_Beads",
9
+ "7": "Bad_blurred",
10
+ "8": "Bad_mixed_phyto",
11
+ "9": "Bad_setae",
12
+ "10": "Centric",
13
+ "11": "Centric_fuzzy",
14
+ "12": "Ceratium_divaricatum",
15
+ "13": "Ceratium_furca",
16
+ "14": "Ceratium_lineatum",
17
+ "15": "Chaetoceros",
18
+ "16": "Ciliate_cutoff",
19
+ "17": "Ciliate_large",
20
+ "18": "Ciliate_large_2",
21
+ "19": "Ciliate_other_morpho_1",
22
+ "20": "Clusterflagellate_morpho_1",
23
+ "21": "Clusterflagellate_morpho_2",
24
+ "22": "Corethron",
25
+ "23": "Cryptophyte",
26
+ "24": "Cylindrotheca_Nitzschia",
27
+ "25": "Detonula_Cerataulina_Lauderia",
28
+ "26": "Detritus",
29
+ "27": "Detritus_infection",
30
+ "28": "Dictyocha",
31
+ "29": "Dinoflagellate_morpho_1",
32
+ "30": "Dinoflagellate_morpho_2",
33
+ "31": "Dinoflagellate_morpho_3",
34
+ "32": "Dinophysis",
35
+ "33": "Ditylum",
36
+ "34": "Entomoneis",
37
+ "35": "Eucampia",
38
+ "36": "Euglenoid",
39
+ "37": "Flagellate_morpho_1",
40
+ "38": "Flagellate_morpho_2",
41
+ "39": "Flagellate_morpho_3",
42
+ "40": "Flagellate_nano_1",
43
+ "41": "Flagellate_nano_2",
44
+ "42": "Fragilariopsis",
45
+ "43": "Guinardia_Dactyliosolen",
46
+ "44": "Gymnodinium",
47
+ "45": "Gyrodinium",
48
+ "46": "Gyrosigma",
49
+ "47": "Haptophyte_prymnesium",
50
+ "48": "Hemiaulus",
51
+ "49": "Hemiselmis",
52
+ "50": "Heterocapsa_morpho_1",
53
+ "51": "Heterocapsa_morpho_2",
54
+ "52": "Heterosigma_akashiwo",
55
+ "53": "Laboea",
56
+ "54": "Leptocylindrus",
57
+ "55": "Lingulodinium",
58
+ "56": "Margalefidinium",
59
+ "57": "Mesodinium",
60
+ "58": "Nano_cluster",
61
+ "59": "Nano_p_white",
62
+ "60": "Pennate_med",
63
+ "61": "Pennate_morpho_1",
64
+ "62": "Pennate_short",
65
+ "63": "Pennate_wide",
66
+ "64": "Peridinium",
67
+ "65": "Phaeocystis",
68
+ "66": "Pleurosigma",
69
+ "67": "Polykrikos",
70
+ "68": "Proboscia",
71
+ "69": "Prorocentrum_narrow",
72
+ "70": "Prorocentrum_wide",
73
+ "71": "Pseudo-nitzschia",
74
+ "72": "Pyramimonas",
75
+ "73": "Rhizosolenia",
76
+ "74": "Scrippsiella",
77
+ "75": "Skeleonema",
78
+ "76": "Skeletonema",
79
+ "77": "Spiky_pacman_circular",
80
+ "78": "Stombidinium_morpho_1",
81
+ "79": "Strombidium_morpho_2",
82
+ "80": "Thalassionema",
83
+ "81": "Thalassiosira",
84
+ "82": "Tiarina",
85
+ "83": "Tintinnid",
86
+ "84": "Tontonia",
87
+ "85": "Torodinium",
88
+ "86": "Tropidoneis",
89
+ "87": "Unknown_morpho_1",
90
+ "88": "Vicicitus"
91
+ }