kostya-cholak commited on
Commit
bef9145
·
1 Parent(s): 6bae95f

fix: restore app

Browse files
Files changed (1) hide show
  1. app.py +119 -10
app.py CHANGED
@@ -1,18 +1,127 @@
 
 
 
 
1
  import streamlit as st
2
 
 
 
3
 
4
- def main():
5
- st.title('Example App')
6
 
7
- # Slider for input
8
- number = st.slider('Select a number', 0, 100, 25)
 
 
9
 
10
- # Calculate the square of the number
11
- square = number ** 2
12
 
13
- # Display the result
14
- st.write('The square of', number, 'is', square)
15
 
 
 
 
16
 
17
- if __name__ == "__main__":
18
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import tempfile
4
+ import cv2
5
  import streamlit as st
6
 
7
+ from ultralytics import YOLO
8
+ from huggingface_hub import hf_hub_url, cached_download
9
 
 
 
10
 
11
+ @st.cache_resource
12
+ def load_model():
13
+ repo_id = "BreakIntoData/cv_workshop"
14
+ filename = "soccer_ball.pt"
15
 
16
+ # Create a URL for the model file on the Hugging Face Hub
17
+ model_url = hf_hub_url(repo_id, filename)
18
 
19
+ # Download the model file from the Hub and cache it locally
20
+ cached_model_path = cached_download(model_url)
21
 
22
+ # Rename the file to have a .pt extension
23
+ new_cached_model_path = f"{cached_model_path}.pt"
24
+ os.rename(cached_model_path, new_cached_model_path)
25
 
26
+ print(f"Downloaded model to {new_cached_model_path}")
27
+
28
+ # Load the model using YOLO from the cached model file
29
+ return YOLO(new_cached_model_path)
30
+
31
+
32
+ def process_video(video_path, output_path):
33
+ cap = cv2.VideoCapture(video_path)
34
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
37
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
38
+
39
+ fourcc = cv2.VideoWriter_fourcc(*'avc1')
40
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
41
+
42
+ progress_text = "Processing video... Please wait."
43
+ progress_bar = st.progress(0)
44
+ status_text = st.empty()
45
+ time_text = st.empty()
46
+
47
+ start_time = time.time()
48
+
49
+ for frame_num in range(total_frames):
50
+ ret, frame = cap.read()
51
+ if not ret:
52
+ break
53
+
54
+ results = model(frame)
55
+ annotated_frame = results[0].plot()
56
+ out.write(annotated_frame)
57
+
58
+ # Update progress
59
+ progress = (frame_num + 1) / total_frames
60
+ elapsed_time = time.time() - start_time
61
+ estimated_total_time = elapsed_time / progress
62
+ remaining_time = estimated_total_time - elapsed_time
63
+
64
+ progress_bar.progress(progress)
65
+ status_text.text(f"Processing frame {frame_num+1}/{total_frames}")
66
+ time_text.text(f"Elapsed time: {elapsed_time:.2f}s | Estimated time remaining: {remaining_time:.2f}s")
67
+
68
+
69
+ cap.release()
70
+ out.release()
71
+ progress_bar.empty()
72
+ status_text.text(f"Processed {total_frames} frames")
73
+ time_text.text(f"Total time: {time.time() - start_time:.2f}s")
74
+
75
+
76
+ model = load_model()
77
+ st.title("Soccer Ball Detection App")
78
+
79
+ # Sidebar for options
80
+ st.sidebar.header("Options")
81
+ video_option = st.sidebar.radio("Choose video source:", ("Use preset video", "Upload video"))
82
+
83
+ if video_option == "Upload video":
84
+ uploaded_file = st.sidebar.file_uploader("Choose a video file", type=["mp4", "avi", "mov"])
85
+ if uploaded_file is not None:
86
+ tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
87
+ tfile.write(uploaded_file.read())
88
+ video_path = tfile.name
89
+ else:
90
+ preset_videos = {
91
+ "Ronaldo": "preset_videos/Ronaldo.mp4",
92
+ "Sancho": "preset_videos/CityUtdR video.mp4",
93
+ "Messi": "preset_videos/Messi.mp4",
94
+ }
95
+ selected_video = st.sidebar.selectbox("Select a preset video", list(preset_videos.keys()))
96
+ video_path = preset_videos[selected_video]
97
+
98
+ if 'video_path' in locals():
99
+ st.header("Original Video")
100
+ st.video(video_path)
101
+
102
+ if st.button("Detect"):
103
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
104
+
105
+ with st.spinner("Processing video..."):
106
+ process_video(video_path, temp_file.name)
107
+
108
+ st.success("Video processing complete!")
109
+ st.header("Processed Video")
110
+ st.video(temp_file.name)
111
+
112
+ # Add download link
113
+ with open(temp_file.name, "rb") as file:
114
+ btn = st.download_button(
115
+ label="Download Video",
116
+ data=file,
117
+ file_name="processed_video.mp4",
118
+ mime="video/mp4"
119
+ )
120
+
121
+ # # Clean up temporary files
122
+ # os.unlink(temp_file.name)
123
+ if video_option == "Upload video":
124
+ os.unlink(video_path)
125
+
126
+ st.sidebar.markdown("---")
127
+ st.sidebar.write("Developed with ❤️ by Break Into Data")