Nonabzbssbbsbs commited on
Commit
952d1ad
·
verified ·
1 Parent(s): 1515998

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
3
+ from PIL import Image
4
+ import torch
5
+ import numpy as np
6
+
7
+ # ЕҢ ТАЗА ЖӘНЕ ҚУАТТЫ МОДЕЛЬ ID-І (90%+ Accuracy)
8
+ MODEL_ID = "keremberke/vit-base-patch16-224-full-empty-trash-bin"
9
+ CLASS_NAMES = ['Empty', 'Full']
10
+
11
+ try:
12
+ # Модельді жүктеу кезінде қате болмауы үшін ең қарапайым әдісті қолданамыз
13
+ feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
14
+ model = ViTForImageClassification.from_pretrained(MODEL_ID)
15
+ MODEL_LOADED = True
16
+ if model.config.id2label:
17
+ CLASS_NAMES = [model.config.id2label[i] for i in model.config.id2label]
18
+
19
+ except Exception as e:
20
+ MODEL_LOADED = False
21
+
22
+ def classify_trash_bin(image):
23
+ if not MODEL_LOADED:
24
+ # Енді "Модель жүктелмеді" дегеннің орнына "Error" деп шығады, бірақ жүктелуі керек
25
+ return {"Error": 1.0, "Check Logs": 0.0}
26
+
27
+ if image is None:
28
+ return {CLASS_NAMES[0]: 0.5, CLASS_NAMES[1]: 0.5}
29
+
30
+ try:
31
+ img = Image.fromarray(image).convert("RGB")
32
+ inputs = feature_extractor(images=img, return_tensors="pt")
33
+
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+
37
+ logits = outputs.logits
38
+ probabilities = torch.softmax(logits, dim=1).squeeze().tolist()
39
+
40
+ if len(probabilities) > 2:
41
+ probabilities = probabilities[:2]
42
+
43
+ results = {CLASS_NAMES[i]: float(probabilities[i]) for i in range(len(CLASS_NAMES))}
44
+ return results
45
+
46
+ except Exception as e:
47
+ return {"Error": 1.0, "Check Logs": 0.0}
48
+
49
+ # Gradio интерфейсін құру
50
+ iface = gr.Interface(
51
+ fn=classify_trash_bin,
52
+ inputs=gr.Image(type="numpy", label="SmartTrachAI Input"),
53
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
54
+ title="SmartTrachAI",
55
+ description="Automated Trash Bin Status Detector."
56
+ )
57
+
58
+ iface.launch()