File size: 1,286 Bytes
ba7b1f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import logging
import os

import torch


MODEL_ID = "google/vit-base-patch16-224"

# Module-level logger: classify() converts exceptions into a status string, so
# without this the traceback of a failed classification would be lost entirely.
logger = logging.getLogger(__name__)


class ImageClassifierService:
    """Image classifier backed by a lazily created transformers pipeline.

    The ``image-classification`` pipeline for ``MODEL_ID`` is built on CPU the
    first time a prediction is requested and is reused for later calls.
    """

    def __init__(self, *, top_k: int = 5, max_threads: int = 4):
        """Configure the service and cap torch's CPU thread count.

        Args:
            top_k: Number of predictions requested from the pipeline per image.
            max_threads: Upper bound passed to ``torch.set_num_threads``; the
                value actually used is also capped at ``os.cpu_count()`` and is
                never below 1.

        Raises:
            ValueError: If ``top_k`` is less than 1.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.top_k = top_k
        # Created on first use by _load_pipeline so construction stays cheap.
        self.pipe = None
        cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
        # NOTE: torch.set_num_threads is a process-wide setting, not per-service.
        torch.set_num_threads(max(1, min(max_threads, cpu_count)))

    def classify(self, image):
        """Classify ``image`` and return a 3-tuple of strings.

        Args:
            image: Object forwarded unchanged to the transformers pipeline, or
                ``None`` when no image was provided.

        Returns:
            ``(top_label, formatted_results, status_message)``. When ``image``
            is ``None``, the model yields no predictions, or any exception is
            raised, the first two elements are empty strings and the status
            message explains why. Exceptions are logged, not propagated.
        """
        if image is None:
            return "", "", "Upload an image first."

        try:
            results = self._run_model(image)
            if not results:
                # Guard explicitly instead of letting results[0] raise IndexError.
                return "", "", "Image classification failed: the model returned no predictions."
            top_label = results[0]["label"]
            formatted = self._format_results(results)
        except Exception as exc:  # boundary: report to caller instead of raising
            logger.exception("Image classification failed")
            return "", "", f"Image classification failed: {type(exc).__name__}: {exc}"
        return top_label, formatted, f"Classified image with {MODEL_ID}."

    def _load_pipeline(self):
        """Create the CPU image-classification pipeline once and cache it on ``self.pipe``."""
        if self.pipe is not None:
            return

        # Imported lazily so importing this module does not pull in transformers.
        from transformers import pipeline

        self.pipe = pipeline(
            "image-classification",
            model=MODEL_ID,
            device=-1,  # -1 selects CPU in transformers pipelines
        )

    def _run_model(self, image):
        """Return the pipeline's top ``self.top_k`` predictions for ``image``."""
        self._load_pipeline()
        return self.pipe(image, top_k=self.top_k)

    def _format_results(self, results):
        """Render predictions as one ``label: NN.N%`` line per item.

        Each item must provide ``"label"`` and ``"score"`` keys; the score is
        multiplied by 100, so it is presumably a fraction in [0, 1].
        """
        return "\n".join(
            f"{item['label']}: {item['score'] * 100:.1f}%" for item in results
        )