Arnab Das committed on
Commit
e07795f
·
1 Parent(s): 1911dce
Files changed (1) hide show
  1. app.py +20 -16
app.py CHANGED
@@ -3,36 +3,40 @@ import gradio as gr
3
  import models as MOD
4
  import process_data as PD
5
 
6
- model_master= {
7
- "SSL-AASIST (Trained on ASV-Spoof5)": {"eer_threshold":3.3330237865448,
8
- "data_process_func": "process_ssl_assist_input" ,
9
  "note": "This model is trained on ASVSpoof 2024 training data.",
10
  "model_class": "Model",
11
- "model_checkpoint":"ssl_aasist_epoch_7.pth"},
12
- "AASIST": {"eer_threshold":1.8018419742584229,
13
- "data_process_func": "process_assist_input" ,
14
- "note": "This model is trained on ASVSpoof 2024 training data."}
15
  }
16
 
17
-
18
- model = MOD.Model(None,"cpu")
19
- model.load_state_dict(torch.load("ssl_aasist_epoch_7.pth",map_location="cpu"))
20
  loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
21
  print("model loaded")
22
 
 
23
  def process(file, type):
 
 
24
  global model
25
  global loaded_model
26
  inp = getattr(PD, model_master[type]["data_process_func"])(file)
27
- print(inp)
28
  if not loaded_model == type:
29
  model = getattr(MOD, model_master[type]["model_class"])(None, "cpu")
30
  model.load_state_dict(torch.load(model_master[type]["model_checkpoint"], map_location="cpu"))
31
  loaded_model = type
32
 
33
- op = model(inp).detach().squeeze()
34
- print("processed")
35
- return str(file)+str(type)+str(model_master[type]["eer_threshold"]) + str(op)
 
 
 
36
 
37
  demo = gr.Blocks()
38
  file_proc = gr.Interface(
@@ -40,7 +44,7 @@ file_proc = gr.Interface(
40
  inputs=[
41
  gr.Audio(sources=["upload"], label="Audio file", type="filepath"),
42
  gr.Radio(["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"], label="Select Model", type="value"),
43
- ],
44
  outputs="text",
45
  title="Find the Fake: Analyze 'Real' or 'Fake'.",
46
  description=(
@@ -53,4 +57,4 @@ file_proc = gr.Interface(
53
  with demo:
54
  gr.TabbedInterface([file_proc], ["Analyze Audio File"])
55
  demo.queue(max_size=10)
56
- demo.launch(share=True)
 
3
  import models as MOD
4
  import process_data as PD
5
 
6
# Registry of selectable models, keyed by the label shown in the Gradio radio.
# Per-entry fields:
#   "eer_threshold":     decision threshold (scores below it are reported fake,
#                        per the response text built in process()).
#   "data_process_func": name of the preprocessing function resolved via
#                        getattr on the process_data module (PD).
#   "model_class":       class name resolved via getattr on the models module
#                        (MOD); absent for "AASIST", which is not yet wired up.
#   "model_checkpoint":  state-dict filename loaded from the working directory;
#                        also absent for "AASIST".
model_master = {
    "SSL-AASIST (Trained on ASV-Spoof5)": {"eer_threshold": 3.3330237865448,
                                           "data_process_func": "process_ssl_assist_input",
                                           "note": "This model is trained on ASVSpoof 2024 training data.",
                                           "model_class": "Model",
                                           "model_checkpoint": "ssl_aasist_epoch_7.pth"},
    "AASIST": {"eer_threshold": 1.8018419742584229,
               "data_process_func": "process_assist_input",
               "note": "This model is trained on ASVSpoof 2024 training data."}
}

# Eagerly load the default model on CPU at import time so the first request
# does not pay the checkpoint-load cost.  `loaded_model` tracks which registry
# key the global `model` currently corresponds to (see process()).
# NOTE(review): assumes `torch` is imported in the file header outside this
# view — confirm against the full app.py.
model = MOD.Model(None, "cpu")
model.load_state_dict(torch.load("ssl_aasist_epoch_7.pth", map_location="cpu"))
loaded_model = "SSL-AASIST (Trained on ASV-Spoof5)"
print("model loaded")
21
 
22
+
23
def process(file, type):
    """Score an uploaded audio file with the selected anti-spoofing model.

    Parameters
    ----------
    file : str
        Filesystem path to the uploaded audio (``gr.Audio`` with
        ``type="filepath"``).
    type : str
        Key into ``model_master`` naming the selected model.  The name
        shadows the ``type`` builtin, but is kept unchanged because Gradio
        binds the interface inputs to this signature.

    Returns
    -------
    str
        Human-readable decision text, or a notice when the selected model
        is unknown or has no loadable checkpoint yet.
    """
    global model
    global loaded_model

    entry = model_master.get(type)
    if entry is None:
        # Defensive: the radio widget should only offer registry keys.
        return "Unknown model selection: {}".format(type)

    # Generalized from the hard-coded `type == "AASIST"` check: any registry
    # entry shipped without a class/checkpoint (as "AASIST" currently is)
    # cannot be loaded, so report that instead of crashing on a missing key.
    if "model_class" not in entry or "model_checkpoint" not in entry:
        return "Model {} is not yet implemented.".format(type)

    inp = getattr(PD, entry["data_process_func"])(file)

    # Swap the cached global model only when the selection actually changed,
    # so repeated requests for the same model skip the checkpoint load.
    if loaded_model != type:
        model = getattr(MOD, entry["model_class"])(None, "cpu")
        model.load_state_dict(torch.load(entry["model_checkpoint"], map_location="cpu"))
        loaded_model = type

    # Index 1 of the squeezed 2-element output is used as the decision score
    # (presumably the bona-fide logit — TODO confirm against the model class).
    op = model(inp).detach().squeeze()[1].item()

    response_text = (
        "Decision score: {} \n Decision threshold: {} \n "
        "Note: 1. Any score below threshold is considered fake. \n 2. {} ".format(
            str(op), str(entry["eer_threshold"]), entry["note"])
    )
    return response_text
39
+
40
 
41
  demo = gr.Blocks()
42
  file_proc = gr.Interface(
 
44
  inputs=[
45
  gr.Audio(sources=["upload"], label="Audio file", type="filepath"),
46
  gr.Radio(["SSL-AASIST (Trained on ASV-Spoof5)", "AASIST"], label="Select Model", type="value"),
47
+ ],
48
  outputs="text",
49
  title="Find the Fake: Analyze 'Real' or 'Fake'.",
50
  description=(
 
57
  with demo:
58
  gr.TabbedInterface([file_proc], ["Analyze Audio File"])
59
  demo.queue(max_size=10)
60
+ demo.launch(share=True)