Mosensei committed on
Commit
ce33ae7
·
verified ·
1 Parent(s): ecc0e2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -179
app.py CHANGED
@@ -1,185 +1,141 @@
1
- import os
2
  import gradio as gr
 
 
3
  from ultralytics import YOLO
4
- import tempfile
5
  import cv2
6
  import numpy as np
7
- import mlflow
8
- from datetime import datetime
9
  import time
10
- import pandas as pd
11
- from collections import defaultdict
12
-
13
-
14
# DagsHub repo coordinates come from environment secrets; both must be set
# for MLflow logging to be enabled.
DAGSHUB_REPO_OWNER = os.getenv("DAGSHUB_REPO_OWNER")
DAGSHUB_REPO_NAME = os.getenv("DAGSHUB_REPO_NAME")

# MLflow logging is opt-in: enabled only when both DagsHub secrets exist,
# so the app still runs (without logging) in local/dev environments.
MLFLOW_ENABLED = False
if DAGSHUB_REPO_OWNER and DAGSHUB_REPO_NAME:
    # Point MLflow at the DagsHub-hosted tracking server for this repo.
    mlflow.set_tracking_uri(f"https://dagshub.com/{DAGSHUB_REPO_OWNER}/{DAGSHUB_REPO_NAME}.mlflow")
    MLFLOW_ENABLED = True
    print("MLflow tracking is configured for DagsHub.")
else:
    print("DagsHub secrets not found. MLflow logging will be disabled.")

# Ultralytics writes settings to a config dir; point it at a writable /tmp
# location (the default home dir may not be writable in containerized hosts).
os.environ.setdefault("YOLO_CONFIG_DIR", "/tmp/Ultralytics")

# Load the detection weights once at import time; reused for all requests.
MODEL_PATH = "best.pt"
model = YOLO(MODEL_PATH)
30
-
31
-
32
def log_image_prediction(input_img_pil, output_image_path, conf, inference_time, detections_df):
    """Log one image prediction (params, metrics, artifacts) to MLflow.

    Args:
        input_img_pil: PIL image the user uploaded.
        output_image_path: Path of the annotated output image on disk.
        conf: Confidence threshold used for the prediction.
        inference_time: Model inference duration in seconds.
        detections_df: DataFrame with one row per detection
            (columns: class_name, confidence).

    Best-effort: any MLflow failure is printed, never raised, so logging
    problems cannot break the user-facing prediction flow.
    """
    if not MLFLOW_ENABLED:
        return
    input_path = None
    try:
        with mlflow.start_run(run_name=f"Image_Prediction_{datetime.now().strftime('%Y%m%d-%H%M%S')}"):
            mlflow.log_param("confidence_threshold", conf)
            mlflow.log_param("prediction_type", "image")
            mlflow.log_metric("inference_time_seconds", inference_time)
            mlflow.log_metric("total_detections", len(detections_df))

            if not detections_df.empty:
                # Per-class counters disambiguate repeated detections of the
                # same class (detection_car_1, detection_car_2, ...).
                class_counts = defaultdict(int)
                for _, row in detections_df.iterrows():
                    class_name = row['class_name']
                    confidence = row['confidence']
                    class_counts[class_name] += 1
                    metric_name = f"detection_{class_name}_{class_counts[class_name]}"
                    mlflow.log_metric(metric_name, confidence)

            # Save the input image to a temp file so it can be attached as an
            # artifact; close the handle before PIL writes to the path.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                input_path = tmp.name
            input_img_pil.save(input_path)
            mlflow.log_artifact(input_path, "input_image")
            mlflow.log_artifact(output_image_path, "output_image")
            # fix: was an f-string with no placeholder
            print("Successfully logged image prediction.")
    except Exception as e:
        print(f"Error logging to MLflow: {e}")
    finally:
        # fix: the temp copy of the input image was previously leaked
        if input_path and os.path.exists(input_path):
            os.remove(input_path)
57
-
58
def log_video_prediction(input_path, output_path, conf):
    """Log one video prediction (params + input/output artifacts) to MLflow.

    Args:
        input_path: Path of the uploaded source video.
        output_path: Path of the annotated output video.
        conf: Confidence threshold used for the prediction.

    Best-effort: MLflow errors are printed and swallowed so the UI keeps working.
    """
    if not MLFLOW_ENABLED:
        return
    try:
        with mlflow.start_run(run_name=f"Video_Prediction_{datetime.now().strftime('%Y%m%d-%H%M%S')}"):
            mlflow.log_param("confidence_threshold", conf)
            mlflow.log_param("prediction_type", "video")
            mlflow.log_artifact(input_path, "input")
            mlflow.log_artifact(output_path, "output")
            # fix: was an f-string with no placeholder
            print("Successfully logged video prediction.")
    except Exception as e:
        print(f"Error logging to MLflow: {e}")
69
-
70
-
71
def run_image_inference(img_pil, conf=0.25):
    """Run YOLO detection on a PIL image.

    Args:
        img_pil: Input PIL image (RGB) or None.
        conf: Confidence threshold for detections.

    Returns:
        (annotated_image_path, inference_time_seconds, detections_df) where
        detections_df has columns class_name / confidence. Returns
        (None, 0.0, empty DataFrame) when no image was supplied.
    """
    if img_pil is None:
        return None, 0.0, pd.DataFrame()

    # OpenCV / YOLO pipeline works on BGR arrays; PIL yields RGB.
    img_bgr = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

    start_time = time.time()
    results = model(img_bgr, conf=conf, iou=0.4, verbose=False, imgsz=640)
    inference_time = time.time() - start_time

    result = results[0]
    detections = [
        {
            "class_name": result.names[int(box.cls.cpu().item())],
            "confidence": round(float(box.conf.cpu().item()), 4),
        }
        for box in result.boxes
    ]
    detections_df = pd.DataFrame(detections)

    # result.plot() returns a BGR array, which is exactly what cv2.imwrite
    # expects — fix: the previous BGR->RGB->BGR round-trip was redundant.
    annotated_img = result.plot()
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        out_path = tmp.name
    cv2.imwrite(out_path, annotated_img)
    return out_path, inference_time, detections_df
93
-
94
def run_video_inference(video_path, conf=0.25, frame_skip=2):
    """Run YOLO detection over a video and write an annotated MP4.

    Args:
        video_path: Path to the input video, or None.
        conf: Confidence threshold for detections.
        frame_skip: Keep 1 frame out of every (frame_skip + 1); -1 keeps all
            frames. The output FPS is scaled down accordingly.

    Returns:
        Path to the annotated .mp4, or None when no input / no decodable frames.
    """
    if video_path is None:
        return None
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        temp_out = tmp.name

    # stream=True yields results frame-by-frame without buffering the video.
    results_generator = model(video_path, conf=conf, iou=0.4, verbose=False, stream=True, imgsz=640)

    try:
        first_result = next(results_generator)
    except StopIteration:
        # Video had no decodable frames.
        return None

    # The capture is opened only to read the source FPS; fall back to 25
    # when the container does not report it (cv2 returns 0).
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    cap.release()

    h, w = first_result.orig_shape
    output_fps = fps / (frame_skip + 1) if frame_skip > -1 else fps
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_out, fourcc, output_fps, (int(w), int(h)))

    try:
        out.write(first_result.plot())
        frame_count = 0
        for result in results_generator:
            frame_count += 1
            # Drop frames to speed up processing; -1 disables skipping.
            if frame_skip > -1 and frame_count % (frame_skip + 1) != 0:
                continue
            out.write(result.plot())
    finally:
        # Always release the writer so the MP4 container is finalized
        # (fix: previously leaked if inference raised mid-stream).
        out.release()
    return temp_out
126
-
127
# Minimal dark-theme CSS injected via gr.HTML (kept as a raw <style> string).
dark_css = """<style> body { background-color: #0f1724; color: #e6eef8; } .gradio-container { background-color: transparent !important; } h1 { color: #ffcc00; } .subtle { color: #9fb0c8; } .card-like { background: rgba(255,255,255,0.03); border-radius: 12px; padding: 12px; } </style>"""

with gr.Blocks() as demo:
    gr.HTML(dark_css)
    gr.Markdown("# 🎯 YOLO Detection Studio — Image & Video")
    gr.Markdown("<div class='subtle'>Upload an image or video, then press Detect.</div>")
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Image"):
                    # Image tab: upload + confidence slider + annotated output.
                    image_input = gr.Image(type="pil", label="Upload Image")
                    img_conf = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Confidence Threshold")
                    img_detect_btn = gr.Button("🔍 Detect Image")
                    image_output = gr.Image(label="Detected Image")
                with gr.TabItem("Video"):
                    # Video tab adds a frame-skip control to trade quality for speed.
                    video_input = gr.Video(label="Upload Video")
                    vid_conf = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Confidence Threshold")
                    frame_skip_slider = gr.Slider(-1, 10, value=2, step=1, label="Frame Skip", info="Process 1 frame every (N+1) frames. -1 to process all frames.")
                    vid_detect_btn = gr.Button("🎬 Detect Video")
                    video_output = gr.Video(label="Detected Video")
        with gr.Column(scale=1):
            # Side panel: status line plus a button that clears both outputs.
            gr.Markdown("### ⚙️ Options & Status")
            status = gr.Textbox(label="Status", value="Ready", interactive=False)
            clear_btn = gr.Button("🧹 Clear Outputs")

    def on_detect_image(img, conf):
        # Click handler: run inference, log to MLflow (best-effort), report status.
        # Returns (annotated image path, status message).
        try:
            out_path, inference_time, detections_df = run_image_inference(img, conf=conf)
            log_image_prediction(img, out_path, conf, inference_time, detections_df)
            status_msg = f"Done. Inference: {inference_time:.2f}s. Detections: {len(detections_df)}."
            if MLFLOW_ENABLED: status_msg += " Logged to DagsHub."
            return out_path, status_msg
        except Exception as e:
            return None, f"Error: {e}"

    def on_detect_video(video_path, conf, frame_skip):
        # Click handler for video; times the whole pipeline, not just inference.
        # Returns (annotated video path or None, status message).
        try:
            start_time = time.time()
            out_path = run_video_inference(video_path, conf=conf, frame_skip=frame_skip)
            end_time = time.time()
            if out_path:
                log_video_prediction(video_path, out_path, conf)
                status_msg = f"Done — video processed in {end_time - start_time:.2f}s."
                if MLFLOW_ENABLED: status_msg += " Logged to DagsHub."
                return out_path, status_msg
            else:
                return None, "Could not process video."
        except Exception as e:
            import traceback
            print(traceback.format_exc())
            return None, f"Error: {e}"

    # Wire buttons: outputs map 1:1 to each handler's returned tuple.
    img_detect_btn.click(fn=on_detect_image, inputs=[image_input, img_conf], outputs=[image_output, status])
    vid_detect_btn.click(fn=on_detect_video, inputs=[video_input, vid_conf, frame_skip_slider], outputs=[video_output, status])

    def on_clear(): return None, "Ready", None
    clear_btn.click(fn=on_clear, inputs=None, outputs=[image_output, status, video_output])

# Bind to all interfaces so the app is reachable inside a container.
demo.launch(server_name="0.0.0.0", share=False)
 
 
1
  import gradio as gr
2
+ import mlflow
3
+ import dagshub
4
  from ultralytics import YOLO
5
+ from PIL import Image
6
  import cv2
7
  import numpy as np
8
+ import os
 
9
  import time
10
+ import tempfile
11
+
12
# ==============================
# MLflow / DagsHub Configuration
# ==============================
# Forward MLflow credentials from the host environment only when they are
# actually set — fix: `os.environ[key] = os.getenv(key)` raised TypeError
# (os.environ values must be str) whenever a secret was missing.
for _key in ("MLFLOW_TRACKING_URI", "MLFLOW_TRACKING_USERNAME", "MLFLOW_TRACKING_PASSWORD"):
    _val = os.getenv(_key)
    if _val is not None:
        os.environ[_key] = _val

# dagshub.init wires MLflow tracking to the DagsHub-hosted server of this repo.
dagshub.init(
    repo_owner="Mosensei7",
    repo_name="AutonomousVehiclesDetectionDEPI",
    mlflow=True
)

mlflow.set_experiment("YOLOv12_Inference")

# ==============================
# Load YOLOv12 Model
# ==============================
# Loaded once at import time and shared across all requests.
model = YOLO("best.pt")  # YOLOv12s weights
31
+
32
+ # ==============================
33
+ # Inference Logic
34
+ # ==============================
35
def run_inference(media_file, media_type):
    """Run YOLO detection on an uploaded image or video and log to MLflow.

    Args:
        media_file: File object from gr.File (exposes a filesystem path via
            ``.name``), or None when nothing was uploaded.
        media_type: "Image" or "Video", from the UI radio.

    Returns:
        (annotated_pil_image_or_None, annotated_video_path_or_None, run_id).
    """
    # Guard: nothing uploaded (fix: previously crashed on `.name` of None).
    if media_file is None:
        return None, None, ""
    media_path = media_file.name

    with mlflow.start_run(run_name=f"Inference_{int(time.time())}") as run:
        mlflow.log_param("media_type", media_type)
        mlflow.log_param("model", "YOLOv12s")

        if media_type == "Image":
            img = Image.open(media_path).convert("RGB")

            results = model(np.array(img))[0]
            annotated = results.plot()
            output_img = Image.fromarray(annotated)

            # Save input/output just long enough to attach them as MLflow
            # artifacts (artifacts are copied, so the tmp dir can be deleted).
            with tempfile.TemporaryDirectory() as tmp:
                in_path = os.path.join(tmp, "input.jpg")
                out_path = os.path.join(tmp, "output.jpg")

                img.save(in_path)
                output_img.save(out_path)

                mlflow.log_artifact(in_path, "inputs")
                mlflow.log_artifact(out_path, "outputs")

            mlflow.log_metric("detections", len(results.boxes))

            return output_img, None, run.info.run_id

        else:
            cap = cv2.VideoCapture(media_path)
            # fix: fall back to 25 FPS when the container does not report it
            # (cv2 returns 0, which would produce an unplayable output file).
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # fix: a fixed filename in the working directory was clobbered by
            # concurrent requests (and cwd may be read-only); use a unique
            # per-request temp directory instead.
            out_path = os.path.join(tempfile.mkdtemp(), "annotated_output.mp4")
            writer = cv2.VideoWriter(
                out_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                fps,
                (w, h)
            )

            frame_count = 0
            total_detections = 0

            try:
                # Process every frame; stop at end-of-stream.
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    results = model(frame)[0]
                    annotated = results.plot()

                    writer.write(annotated)
                    frame_count += 1
                    total_detections += len(results.boxes)
            finally:
                # Always finalize the MP4 container, even if inference raises
                # (fix: resources previously leaked on error).
                cap.release()
                writer.release()

            mlflow.log_artifact(media_path, "inputs")
            mlflow.log_artifact(out_path, "outputs")
            mlflow.log_metric("frames", frame_count)
            mlflow.log_metric("total_detections", total_detections)

            return None, out_path, run.info.run_id
103
# ==============================
# Futuristic UI
# ==============================
# Custom CSS for a neon/dark look, passed to gr.Blocks below.
css = """
body {
    background: linear-gradient(135deg, #0f0c29, #302b63, #24243e);
    color: white;
    font-family: 'Orbitron', sans-serif;
}
.gradio-container {
    border: 2px solid cyan;
    border-radius: 20px;
    box-shadow: 0 0 20px cyan;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("""
<h1 style='text-align:center;color:cyan;'>YOLOv12 Autonomous Vehicle Detection</h1>
<p style='text-align:center;'>All inferences are logged to DagsHub MLflow</p>
""")

    with gr.Row():
        # A single uploader handles both media kinds; the radio selects which
        # branch run_inference takes.
        media = gr.File(label="Upload Image / Video")
        media_type = gr.Radio(["Image", "Video"], value="Image")

    detect = gr.Button("Run Detection")

    img_out = gr.Image(label="Image Result")
    vid_out = gr.Video(label="Video Result")
    run_id = gr.Textbox(label="MLflow Run ID")

    # Wire the button to the inference entry point; outputs map 1:1 to the
    # (image, video_path, run_id) tuple that run_inference returns.
    detect.click(
        run_inference,
        inputs=[media, media_type],
        outputs=[img_out, vid_out, run_id]
    )

# NOTE(review): share=True is ignored on hosted Spaces (they provide their own
# public URL) — confirm whether it is needed for local runs.
demo.launch(share=True)