Mayank022 committed on
Commit
ea70237
·
verified ·
1 Parent(s): b2ac71b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +54 -0
inference.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Optional

import torch
import torchaudio
import transformers

from config import ModelConfig
from model import MultiModalModel
7
+
8
def run_inference(audio_path: str, model_path: Optional[str] = None) -> str:
    """Transcribe an audio file with the multimodal model.

    Args:
        audio_path: Path to the input audio file (any format torchaudio can read).
        model_path: Optional directory containing a ``pytorch_model.bin``
            checkpoint to load. When None the model runs with freshly
            initialized weights.

    Returns:
        The decoded transcription string.
    """
    # Build the model from its static config.
    config = ModelConfig()
    model = MultiModalModel(config)

    if model_path:
        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
        # strict=False tolerates checkpoints that only cover part of the
        # model, but silently ignoring mismatches hides broken loads —
        # report how many keys failed to match.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(
                f"load_state_dict: {len(missing)} missing / "
                f"{len(unexpected)} unexpected keys"
            )

    model.eval()

    # --- Audio preprocessing -------------------------------------------
    processor = transformers.AutoProcessor.from_pretrained(config.audio_model_id)
    audio, sr = torchaudio.load(audio_path)
    if sr != 16000:
        # The audio encoder expects 16 kHz input.
        audio = torchaudio.functional.resample(audio, sr, 16000)
    if audio.shape[0] > 1:
        # Downmix multi-channel audio to mono.
        audio = audio.mean(dim=0, keepdim=True)

    audio_inputs = processor(audio.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
    audio_values = audio_inputs.input_features

    # --- Text prompt ----------------------------------------------------
    tokenizer = transformers.AutoTokenizer.from_pretrained(config.text_model_id)
    text = "Transcribe the following audio:"
    text_inputs = tokenizer(text, return_tensors="pt")

    # --- Generation -----------------------------------------------------
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=text_inputs.input_ids,
            audio_values=audio_values,
            max_new_tokens=200,
        )

    transcription = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print("Transcription:", transcription)
    return transcription
48
+
49
+ if __name__ == "__main__":
50
+ import sys
51
+ if len(sys.argv) > 1:
52
+ run_inference(sys.argv[1])
53
+ else:
54
+ print("Usage: python -m audio_lm.inference path/to/audio.wav")