GwFirman commited on
Commit
7df3a73
·
verified ·
1 Parent(s): 4e1b71b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -32
app.py CHANGED
@@ -1,40 +1,89 @@
 
1
  import torch
2
  from PIL import Image
3
- import gradio as gr
 
 
4
 
5
- # Load trained model
6
- model = torch.load("best.pt", map_location=torch.device("cpu"))
7
- model.eval()
8
 
9
- # Function to process image
10
- def predict(image):
11
- # Convert to PIL image if needed
12
- if not isinstance(image, Image.Image):
13
- image = Image.fromarray(image)
14
-
15
- # Preprocess image (adjust as per your model's requirements)
16
- results = model([image]) # Assuming YOLOv11 inference works like this
17
- detections = results.xyxy[0].numpy() # Extract bounding boxes, scores, etc.
18
 
19
- # Draw boxes on image
20
- for box in detections:
21
- x1, y1, x2, y2, conf, cls = box
22
- label = f"Class {int(cls)}: {conf:.2f}"
23
- draw = ImageDraw.Draw(image)
24
- draw.rectangle([(x1, y1), (x2, y2)], outline="red", width=3)
25
- draw.text((x1, y1), label, fill="red")
26
-
27
  return image
28
 
29
- # Gradio interface
30
- interface = gr.Interface(
31
- fn=predict,
32
- inputs=gr.Image(type="pil"),
33
- outputs="image",
34
- title="YOLOv11 Object Detection",
35
- description="Upload an image to detect objects using YOLOv11.",
36
- )
 
 
 
 
 
 
 
 
 
37
 
38
- # Launch app
39
- if __name__ == "__main__":
40
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import torch
3
  from PIL import Image
4
+ import numpy as np
5
+ import cv2
6
+ import matplotlib.pyplot as plt
7
 
8
+ # Konfigurasi model
9
+ MODEL_PATH = "model.pth" # Pastikan model ada di direktori ini
10
+ CLASS_NAMES = ["bag", "person-static-object"] # Sesuaikan dengan kelas Anda
11
 
12
+ # Fungsi untuk memuat model
13
+ @st.cache_resource
14
+ def load_model():
15
+ model = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
16
+ model.eval() # Mode evaluasi
17
+ return model
 
 
 
18
 
19
+ # Preprocessing gambar
20
+ def preprocess(image):
21
+ # Resize dan konversi ke tensor
22
+ image = np.array(image)
23
+ image = cv2.resize(image, (640, 640)) # Sesuaikan dengan input size model
24
+ image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
25
+ image = image.unsqueeze(0) # Tambahkan batch dimension
 
26
  return image
27
 
28
+ # Postprocessing hasil prediksi
29
+ def postprocess(prediction, confidence_threshold=0.5):
30
+ # Contoh untuk deteksi objek YOLO-style (sesuaikan dengan output model Anda)
31
+ boxes = prediction[0]["boxes"].detach().numpy()
32
+ scores = prediction[0]["scores"].detach().numpy()
33
+ labels = prediction[0]["labels"].detach().numpy()
34
+
35
+ # Filter berdasarkan confidence threshold
36
+ keep = scores >= confidence_threshold
37
+ return boxes[keep], scores[keep], labels[keep]
38
+
39
+ # Antarmuka Streamlit
40
+ st.title("Deteksi Objek Tertinggal 👜👤")
41
+ st.write("Upload gambar untuk mendeteksi objek 'bag' atau 'person-static-object'")
42
+
43
+ # Upload gambar
44
+ uploaded_file = st.file_uploader("Pilih gambar...", type=["jpg", "png", "jpeg"])
45
 
46
+ if uploaded_file is not None:
47
+ # Memuat gambar
48
+ image = Image.open(uploaded_file).convert("RGB")
49
+ st.image(image, caption="Gambar Input", use_column_width=True)
50
+
51
+ # Proses deteksi
52
+ if st.button("Deteksi Objek"):
53
+ with st.spinner("Memproses..."):
54
+ try:
55
+ # Memuat model
56
+ model = load_model()
57
+
58
+ # Preprocessing
59
+ input_tensor = preprocess(image)
60
+
61
+ # Prediksi
62
+ with torch.no_grad():
63
+ prediction = model(input_tensor)
64
+
65
+ # Postprocessing
66
+ boxes, scores, labels = postprocess(prediction)
67
+
68
+ # Visualisasi hasil
69
+ fig, ax = plt.subplots(1, figsize=(12, 6))
70
+ ax.imshow(image)
71
+
72
+ for box, score, label in zip(boxes, scores, labels):
73
+ x1, y1, x2, y2 = box
74
+ rect = plt.Rectangle(
75
+ (x1, y1), x2 - x1, y2 - y1,
76
+ linewidth=2, edgecolor="lime", facecolor="none"
77
+ )
78
+ ax.add_patch(rect)
79
+ ax.text(
80
+ x1, y1 - 5,
81
+ f"{CLASS_NAMES[label]}: {score:.2f}",
82
+ color="lime", fontsize=10, backgroundcolor="black"
83
+ )
84
+
85
+ st.pyplot(fig)
86
+ st.success("Selesai!")
87
+
88
+ except Exception as e:
89
+ st.error(f"Error: {str(e)}")