aneeshm44 commited on
Commit
6237077
·
verified ·
1 Parent(s): 6596dd9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -6,12 +6,16 @@ import torch.nn.functional as F
6
  import torchvision.models as models
7
  import tempfile
8
  import os
 
9
 
10
  if "uploaded_file" not in st.session_state:
11
  st.session_state.uploaded_file = None
 
 
12
 
13
  def reset_state():
14
  st.session_state.uploaded_file = None
 
15
  st.rerun()
16
 
17
  st.markdown(
@@ -90,11 +94,14 @@ def main():
90
  st.markdown('<div class="main">', unsafe_allow_html=True)
91
  st.title("Audio Deepfake Detector")
92
  st.write("Upload a **.wav** file to check if it's **Real** or **Fake**.")
93
-
94
- if st.session_state.uploaded_file is None:
95
- uploaded_file = st.file_uploader("Choose a .wav file", type=["wav"])
96
- if uploaded_file is not None:
97
- st.session_state.uploaded_file = uploaded_file
 
 
 
98
 
99
  if st.session_state.uploaded_file is not None:
100
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
@@ -102,6 +109,7 @@ def main():
102
  tmp_path = tmp_file.name
103
 
104
  st.audio(st.session_state.uploaded_file, format="audio/wav")
 
105
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106
  checkpoint_path = "best_model.pth"
107
  try:
@@ -112,8 +120,10 @@ def main():
112
  st.error(f"Error during prediction: {e}")
113
  finally:
114
  os.remove(tmp_path)
 
115
  if st.button("Reset"):
116
  reset_state()
 
117
  st.markdown("</div>", unsafe_allow_html=True)
118
 
119
  if __name__ == "__main__":
 
6
  import torchvision.models as models
7
  import tempfile
8
  import os
9
+ import uuid
10
 
11
  if "uploaded_file" not in st.session_state:
12
  st.session_state.uploaded_file = None
13
+ if "uploader_key" not in st.session_state:
14
+ st.session_state.uploader_key = str(uuid.uuid4())
15
 
16
  def reset_state():
17
  st.session_state.uploaded_file = None
18
+ st.session_state.uploader_key = str(uuid.uuid4())
19
  st.rerun()
20
 
21
  st.markdown(
 
94
  st.markdown('<div class="main">', unsafe_allow_html=True)
95
  st.title("Audio Deepfake Detector")
96
  st.write("Upload a **.wav** file to check if it's **Real** or **Fake**.")
97
+
98
+ uploaded_file = st.file_uploader(
99
+ "Choose a .wav file",
100
+ type=["wav"],
101
+ key=st.session_state.uploader_key
102
+ )
103
+ if uploaded_file is not None:
104
+ st.session_state.uploaded_file = uploaded_file
105
 
106
  if st.session_state.uploaded_file is not None:
107
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
 
109
  tmp_path = tmp_file.name
110
 
111
  st.audio(st.session_state.uploaded_file, format="audio/wav")
112
+
113
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
114
  checkpoint_path = "best_model.pth"
115
  try:
 
120
  st.error(f"Error during prediction: {e}")
121
  finally:
122
  os.remove(tmp_path)
123
+
124
  if st.button("Reset"):
125
  reset_state()
126
+
127
  st.markdown("</div>", unsafe_allow_html=True)
128
 
129
  if __name__ == "__main__":