Nonabzbssbbsbs commited on
Commit
f1a734d
·
verified ·
1 Parent(s): df785e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # *******************************************************************
7
+ # ЕҢ ТАЗА ЖӘНЕ ҚУАТТЫ МОДЕЛЬ ID-І (90%+ Accuracy, таза PyTorch)
8
+ # *******************************************************************
9
+ MODEL_ID = "keremberke/vit-base-patch16-224-full-empty-trash-bin"
10
+ CLASS_NAMES = ['Empty', 'Full']
11
+
12
+ try:
13
+ feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
14
+ model = ViTForImageClassification.from_pretrained(MODEL_ID)
15
+ MODEL_LOADED = True
16
+ # Модельді 2 классқа бейімдейміз (қате болмауы үшін)
17
+ if model.config.id2label:
18
+ CLASS_NAMES = [model.config.id2label[i] for i in model.config.id2label]
19
+
20
+ except Exception as e:
21
+ print(f"ERROR: Model loading failed: {e}")
22
+ MODEL_LOADED = False
23
+
24
+ def classify_trash_bin(image):
25
+ if not MODEL_LOADED:
26
+ return {"Error": 1.0, "Check Logs": 0.0}
27
+
28
+ if image is None:
29
+ return {CLASS_NAMES[0]: 0.5, CLASS_NAMES[1]: 0.5}
30
+
31
+ try:
32
+ img = Image.fromarray(image).convert("RGB")
33
+ inputs = feature_extractor(images=img, return_tensors="pt")
34
+
35
+ with torch.no_grad():
36
+ outputs = model(**inputs)
37
+
38
+ logits = outputs.logits
39
+ probabilities = torch.softmax(logits, dim=1).squeeze().tolist()
40
+
41
+ # Тек алғашқы 2 класты қайтару
42
+ if len(probabilities) > 2:
43
+ probabilities = probabilities[:2]
44
+
45
+ results = {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))}
46
+ return results
47
+
48
+ except Exception as e:
49
+ return {"Error": 1.0, "Check Logs": 0.0}
50
+
51
+ # Gradio интерфейсін құру
52
+ iface = gr.Interface(
53
+ fn=classify_trash_bin,
54
+ inputs=gr.Image(type="numpy", label="SmartTrachAI Input"),
55
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
56
+ title="SmartTrachAI",
57
+ description="Automated Trash Bin Status Detector."
58
+ )
59
+
60
+ iface.launch()