Patrick Daniel commited on
Commit
7cea9f7
·
1 Parent(s): ff45b18

Fixed label

Browse files
Files changed (1) hide show
  1. app.py +60 -34
app.py CHANGED
@@ -1,47 +1,73 @@
1
  import gradio as gr
2
  import torch
3
- from PIL import Image
4
  from transformers import ViTForImageClassification, ViTImageProcessor
 
5
  import requests
6
- import json
7
  import os
8
 
9
- hf_token = os.getenv("HF_TOKEN")
 
 
10
 
11
-
12
- # Load model and processor from Hugging Face Hub
13
- model = ViTForImageClassification.from_pretrained("patcdaniel/phytoViT_508k_20250611",token=hf_token)
14
- processor = ViTImageProcessor.from_pretrained("patcdaniel/phytoViT_508k_20250611",token=hf_token)
15
  model.eval()
16
 
17
- # Load class labels from hosted file
18
- LABELS_URL = "https://huggingface.co/patcdaniel/phytoViT_508k_20250611/resolve/main/label_names.json"
19
- headers = {"Authorization": f"Bearer {hf_token}"}
20
- class_labels = requests.get(LABELS_URL, headers=headers).json()
21
 
 
22
  def predict(image):
23
- image = image.convert("RGB")
24
- inputs = processor(images=image, return_tensors="pt")
25
- with torch.no_grad():
26
- logits = model(**inputs).logits
27
- probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
28
-
29
- # Get top 2 predictions
30
- topk = torch.topk(probs, k=2)
31
- top_scores = topk.values.tolist()
32
- top_labels = [class_labels[i] for i in topk.indices.tolist()]
33
-
34
- # Format output
35
- output = {label: round(score, 4) for label, score in zip(top_labels, top_scores)}
36
- return output
37
-
38
- # Gradio interface
39
- demo = gr.Interface(
40
- fn=predict,
41
- inputs=gr.Image(type="pil", label="Upload or Paste an Image"),
42
- outputs=gr.Label(num_top_classes=2, label="Top Predictions"),
43
- title="PhytoViT Classifier",
44
- description="Upload an IFCB phytoplankton image or paste an image URL to classify it using a ViT model trained on 508k examples."
45
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  demo.launch()
 
1
  import gradio as gr
2
  import torch
 
3
  from transformers import ViTForImageClassification, ViTImageProcessor
4
+ from PIL import Image
5
  import requests
6
+ from io import BytesIO
7
  import os
8
 
9
+ # Authenticate with Hugging Face Hub for private model access
10
+ from huggingface_hub import login
11
+ login(token=os.environ.get("HF_TOKEN")) # Set this in your Space's Secrets tab
12
 
13
+ # Load model and processor
14
+ model = ViTForImageClassification.from_pretrained("patcdaniel/phytoViT_508k_20250611")
15
+ processor = ViTImageProcessor.from_pretrained("patcdaniel/phytoViT_508k_20250611")
 
16
  model.eval()
17
 
18
+ # Use GPU if available
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ model.to(device)
 
21
 
22
+ # Inference function
23
  def predict(image):
24
+ try:
25
+ image = image.convert("RGB")
26
+ inputs = processor(images=image, return_tensors="pt").to(device)
27
+
28
+ with torch.no_grad():
29
+ logits = model(**inputs).logits
30
+ probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
31
+
32
+ topk = torch.topk(probs, k=2)
33
+ top_indices = topk.indices.tolist()
34
+ top_scores = topk.values.tolist()
35
+
36
+ id2label = model.config.id2label
37
+ top_labels = [id2label[str(i)] for i in top_indices]
38
+
39
+ return {label: round(score, 4) for label, score in zip(top_labels, top_scores)}
40
+
41
+ except Exception as e:
42
+ import traceback
43
+ print(traceback.format_exc())
44
+ return {"Error": str(e)}
45
+
46
+ # Optional: allow input via URL
47
+ def classify_from_url(url):
48
+ try:
49
+ response = requests.get(url)
50
+ image = Image.open(BytesIO(response.content))
51
+ return predict(image)
52
+ except Exception as e:
53
+ return {"Error": f"Could not load image from URL. {e}"}
54
+
55
+ # Gradio UI
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown("# PhytoViT - IFCB Phytoplankton Classifier")
58
+ gr.Markdown("Upload an image or paste a URL. Model: `phytoViT_508k_20250611`")
59
+
60
+ with gr.Row():
61
+ with gr.Column():
62
+ image_input = gr.Image(type="pil", label="Upload Image")
63
+ url_input = gr.Textbox(label="...or paste image URL")
64
+ predict_btn = gr.Button("Classify")
65
+
66
+ with gr.Column():
67
+ image_output = gr.Image(label="Input Image")
68
+ label_output = gr.Label(label="Top 2 Predictions")
69
+
70
+ predict_btn.click(fn=predict, inputs=image_input, outputs=label_output)
71
+ url_input.change(fn=classify_from_url, inputs=url_input, outputs=label_output)
72
 
73
  demo.launch()