masakljun commited on
Commit
1cdf7ce
·
1 Parent(s): 5e59d0a

First commit of LightlyTrain app

Browse files
Files changed (2) hide show
  1. app.py +110 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import supervision as sv
4
+ from PIL import Image
5
+ import lightly_train # Ensure this matches the installed package name
6
+
7
+ # --- CONFIGURATION ---
8
+ # We use a default LightlyTrain model so the Space works immediately.
9
+ DEFAULT_MODEL_NAME = "dinov3/vitt16-ltdetr-coco"
10
+
11
+ # --- HELPER FUNCTIONS ---
12
+ def load_lightly_model(model_name):
13
+ print(f"Loading model: {model_name}...")
14
+ # This automatically downloads the pretrained model from Lightly
15
+ return lightly_train.load_model(model_name)
16
+
17
+ # Initialize model once at startup to save time
18
+ model = load_lightly_model(DEFAULT_MODEL_NAME)
19
+
20
+ def predict_and_annotate(image, confidence_threshold, model_name):
21
+ # 1. Run Prediction using LightlyTrain
22
+ # LightlyTrain's predict method handles PIL images directly
23
+ results = model.predict(image)
24
+
25
+ # LightlyTrain returns a dictionary: {'bboxes': Tensor, 'labels': Tensor, 'scores': Tensor}
26
+ # We move tensors to CPU and convert to numpy for Supervision
27
+ boxes = results['bboxes'].cpu().numpy()
28
+ labels = results['labels'].cpu().numpy()
29
+ scores = results['scores'].cpu().numpy()
30
+
31
+ # 2. Filter by Confidence
32
+ valid_indices = scores > confidence_threshold
33
+ boxes = boxes[valid_indices]
34
+ labels = labels[valid_indices]
35
+ scores = scores[valid_indices]
36
+
37
+ # 3. Convert to Supervision Detections format
38
+ detections = sv.Detections(
39
+ xyxy=boxes,
40
+ confidence=scores,
41
+ class_id=labels
42
+ )
43
+
44
+ # 4. Annotate the Image
45
+ box_annotator = sv.BoxAnnotator()
46
+ label_annotator = sv.LabelAnnotator()
47
+
48
+ # Create label text (e.g., "Class: 0 0.85")
49
+ # Note: If you have a class names list (like COCO_CLASSES), you can map IDs to names here.
50
+ generated_labels = [
51
+ f"Class {class_id} {confidence:.2f}"
52
+ for class_id, confidence in zip(detections.class_id, detections.confidence)
53
+ ]
54
+
55
+ annotated_image = image.copy()
56
+ annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
57
+ annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections, labels=generated_labels)
58
+
59
+ return annotated_image
60
+
61
+ # --- GRADIO UI ---
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown("# LightlyTrain Object Detection Demo 🚀")
64
+ gr.Markdown("This demo uses **LightlyTrain** with a **DINOv3** backbone to detect objects.")
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ input_img = gr.Image(type="pil", label="Input Image")
69
+
70
+ conf_slider = gr.Slider(
71
+ minimum=0.0,
72
+ maximum=1.0,
73
+ value=0.3,
74
+ step=0.05,
75
+ label="Confidence Threshold"
76
+ )
77
+
78
+ # Dropdown for model selection (currently just one default)
79
+ model_selector = gr.Dropdown(
80
+ choices=[DEFAULT_MODEL_NAME],
81
+ value=DEFAULT_MODEL_NAME,
82
+ label="Model Checkpoint"
83
+ )
84
+
85
+ run_btn = gr.Button("Run Detection", variant="primary")
86
+
87
+ with gr.Column():
88
+ output_img = gr.Image(label="Annotated Result")
89
+
90
+ # Connect the button to the function
91
+ run_btn.click(
92
+ fn=predict_and_annotate,
93
+ inputs=[input_img, conf_slider, model_selector],
94
+ outputs=output_img
95
+ )
96
+
97
+ # Example images for quick testing
98
+ gr.Examples(
99
+ examples=[
100
+ ["https://media.roboflow.com/notebooks/examples/dog-2.jpeg", 0.3, DEFAULT_MODEL_NAME],
101
+ ["https://media.roboflow.com/supervision/image-examples/vehicles.png", 0.3, DEFAULT_MODEL_NAME]
102
+ ],
103
+ inputs=[input_img, conf_slider, model_selector],
104
+ outputs=output_img,
105
+ fn=predict_and_annotate,
106
+ cache_examples=True,
107
+ )
108
+
109
+ if __name__ == "__main__":
110
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ lightly-train
2
+ supervision
3
+ gradio
4
+ torch
5
+ torchvision
6
+ numpy
7
+ Pillow