mariam-ahmed15 commited on
Commit
3fc49d8
·
verified ·
1 Parent(s): b7e88e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -15
app.py CHANGED
@@ -5,34 +5,31 @@ from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtra
5
 
6
  # 1. CONFIGURATION
7
  MODEL_ID = "facebook/wav2vec2-xls-r-300m"
8
- QUANTIZED_MODEL_PATH = "quantized_model.pth"
9
 
10
- # 2. LOAD MODEL
11
  print("Loading model architecture...")
12
- # A. Load the skeleton (empty weights)
13
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
14
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
15
 
16
- # B. Apply the quantization structure (Must happen BEFORE loading weights)
17
- # This changes the Linear layers to INT8 format so the keys match
18
  model = torch.quantization.quantize_dynamic(
19
  model, {torch.nn.Linear}, dtype=torch.qint8
20
  )
21
 
22
- # C. Load your trained quantized weights
23
  print("Loading quantized weights...")
24
  model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH, map_location=torch.device('cpu')))
25
  model.eval()
26
 
27
- # 3. DEFINE PREDICTION FUNCTION
28
  def predict_audio(audio_path):
29
  if audio_path is None:
30
  return "No Audio Provided"
31
 
32
- # Load and resample audio to 16kHz
33
  speech_array, sr = librosa.load(audio_path, sr=16000)
34
 
35
- # Process inputs
36
  inputs = feature_extractor(
37
  speech_array,
38
  sampling_rate=16000,
@@ -43,10 +40,9 @@ def predict_audio(audio_path):
43
  with torch.no_grad():
44
  logits = model(**inputs).logits
45
 
46
- # Convert logits to probabilities
47
  probs = torch.nn.functional.softmax(logits, dim=-1)
48
 
49
- # Assuming Label 0 = Real, Label 1 = Deepfake (Adjust based on your training!)
50
  fake_prob = probs[0][1].item()
51
  real_prob = probs[0][0].item()
52
 
@@ -55,14 +51,17 @@ def predict_audio(audio_path):
55
  "Real": real_prob
56
  }
57
 
58
- # 4. CREATE API INTERFACE
59
- # This creates a visual UI *and* a hidden API endpoint
60
  iface = gr.Interface(
61
  fn=predict_audio,
62
- inputs=gr.Audio(type="filepath"),
 
 
 
 
63
  outputs=gr.Label(num_top_classes=2),
64
  title="Deepfake Audio Detection API",
65
- description="Upload an audio file to check if it's real or fake."
66
  )
67
 
68
  iface.launch()
 
5
 
6
  # 1. CONFIGURATION
7
  MODEL_ID = "facebook/wav2vec2-xls-r-300m"
8
+ QUANTIZED_MODEL_PATH = "quantized_model.pth"
9
 
10
+ # 2. LOAD MODEL
11
  print("Loading model architecture...")
 
12
  model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_ID, num_labels=2)
13
  feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_ID)
14
 
15
+ # Apply quantization structure
 
16
  model = torch.quantization.quantize_dynamic(
17
  model, {torch.nn.Linear}, dtype=torch.qint8
18
  )
19
 
20
+ # Load weights
21
  print("Loading quantized weights...")
22
  model.load_state_dict(torch.load(QUANTIZED_MODEL_PATH, map_location=torch.device('cpu')))
23
  model.eval()
24
 
25
+ # 3. PREDICTION FUNCTION
26
  def predict_audio(audio_path):
27
  if audio_path is None:
28
  return "No Audio Provided"
29
 
30
+ # Load and resample
31
  speech_array, sr = librosa.load(audio_path, sr=16000)
32
 
 
33
  inputs = feature_extractor(
34
  speech_array,
35
  sampling_rate=16000,
 
40
  with torch.no_grad():
41
  logits = model(**inputs).logits
42
 
 
43
  probs = torch.nn.functional.softmax(logits, dim=-1)
44
 
45
+ # Label 0 = Real, Label 1 = Deepfake (Double check your own labels!)
46
  fake_prob = probs[0][1].item()
47
  real_prob = probs[0][0].item()
48
 
 
51
  "Real": real_prob
52
  }
53
 
54
+ # 4. CREATE INTERFACE (Modified for Upload Only)
 
55
  iface = gr.Interface(
56
  fn=predict_audio,
57
+ inputs=gr.Audio(
58
+ sources=["upload"],
59
+ type="filepath",
60
+ label="Upload Audio File"
61
+ ),
62
  outputs=gr.Label(num_top_classes=2),
63
  title="Deepfake Audio Detection API",
64
+ description="Upload an audio file (WAV/MP3) to check if it's real or fake."
65
  )
66
 
67
  iface.launch()