Mohansai2004 commited on
Commit
1cfb24a
·
verified ·
1 Parent(s): d2be60c

Update app/model.py

Browse files
Files changed (1) hide show
  1. app/model.py +5 -5
app/model.py CHANGED
@@ -1,24 +1,24 @@
1
  from transformers import ViTImageProcessor, ViTForImageClassification
2
- from PIL import Image
3
  import torch
 
4
 
5
  MODEL_NAME = "google/vit-base-patch16-224"
6
 
7
- # Load processor and model once at startup
8
  processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
9
  model = ViTForImageClassification.from_pretrained(MODEL_NAME)
10
 
 
11
  def analyze_image(image: Image.Image):
12
  inputs = processor(images=image, return_tensors="pt")
13
  with torch.no_grad():
14
  outputs = model(**inputs)
15
-
16
  logits = outputs.logits
17
  predicted_class_idx = logits.argmax(-1).item()
18
  label = model.config.id2label[predicted_class_idx]
19
- score = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
20
 
21
  return {
22
  "label": label,
23
- "confidence": round(score, 3)
24
  }
 
1
  from transformers import ViTImageProcessor, ViTForImageClassification
 
2
  import torch
3
+ from PIL import Image
4
 
5
  MODEL_NAME = "google/vit-base-patch16-224"
6
 
7
+ # Load once at startup
8
  processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
9
  model = ViTForImageClassification.from_pretrained(MODEL_NAME)
10
 
11
+
12
  def analyze_image(image: Image.Image):
13
  inputs = processor(images=image, return_tensors="pt")
14
  with torch.no_grad():
15
  outputs = model(**inputs)
 
16
  logits = outputs.logits
17
  predicted_class_idx = logits.argmax(-1).item()
18
  label = model.config.id2label[predicted_class_idx]
19
+ confidence = torch.nn.functional.softmax(logits, dim=-1)[0, predicted_class_idx].item()
20
 
21
  return {
22
  "label": label,
23
+ "confidence": round(confidence, 4)
24
  }