madharjan committed on
Commit
b074071
·
1 Parent(s): b2f8d54

Refactor model loading and inference logic in app.py; update requirements.txt for package versions

Browse files
Files changed (2) hide show
  1. app.py +71 -38
  2. requirements.txt +6 -10
app.py CHANGED
@@ -2,60 +2,93 @@ import gradio as gr
2
  import torch
3
  import librosa
4
  from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
 
5
 
6
- # Load model and processor
7
- repo_id = "MERaLiON/MERaLiON-2-10B"
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
11
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
12
- repo_id,
13
- use_safetensors=True,
14
- trust_remote_code=True,
15
- attn_implementation="flash_attention_2",
16
- torch_dtype=torch.bfloat16
17
- ).to(device)
 
 
 
 
 
 
 
 
 
 
18
 
19
  def meralion_inference(prompt, uploaded_file):
 
 
20
  if uploaded_file is None:
21
  return "Please upload an audio file."
22
-
23
- # Prompt template and example prompts
24
- prompt_template = "Instruction: {query} \nFollow the text instruction based on the following audio: <SpeechHere>"
25
 
26
- audio_array, sr = librosa.load(uploaded_file.name, sr=16000)
 
27
 
28
- # Create conversation and apply chat template
29
- conversation = [{"role": "user", "content": prompt_template.format(query=prompt)}]
30
- chat_prompt = processor.tokenizer.apply_chat_template(
31
- conversation=conversation, tokenize=False, add_generation_prompt=True
32
- )
33
 
34
- # Process inputs
35
- inputs = processor(text=chat_prompt, audios=audio_array)
 
 
 
 
 
 
36
 
37
- # Move tensors to device and cast float32 to bfloat16
38
- for key, value in inputs.items():
39
- if isinstance(value, torch.Tensor):
40
- inputs[key] = inputs[key].to(device)
41
- if value.dtype == torch.float32:
42
- inputs[key] = inputs[key].to(torch.bfloat16)
43
 
44
- # Generate response
45
- outputs = model.generate(**inputs, max_new_tokens=256)
46
- generated_ids = outputs[:, inputs['input_ids'].size(1):]
47
- response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
48
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  with gr.Blocks() as demo:
52
- gr.Markdown("# Meralion Model Demo with Prompt and File Upload")
53
  with gr.Row():
54
- prompt_input = gr.Textbox(label="Enter Prompt")
55
- file_input = gr.File(label="Upload File")
56
- output_text = gr.Textbox(label="Model Output")
 
 
 
 
 
57
 
58
- submit_btn = gr.Button("Run Model")
59
- submit_btn.click(meralion_inference, inputs=[prompt_input, file_input], outputs=output_text)
 
 
60
 
61
  demo.launch()
 
2
  import torch
3
  import librosa
4
  from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
5
+ import os
6
 
7
+ # Global model cache
8
+ model = None
9
+ processor = None
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+
13
def load_model():
    """Lazily initialize the MERaLiON processor and model, caching them globally.

    The heavyweight download/initialization runs only on the first call;
    subsequent calls return the cached module-level instances.

    Returns:
        tuple: ``(model, processor)`` — the cached Hugging Face objects.
    """
    global model, processor

    # Fast path: already initialized by a previous call.
    if model is not None:
        return model, processor

    repo_id = "MERaLiON/MERaLiON-2-10B"
    print("Loading MERaLiON-2-10B model...")
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        repo_id,
        use_safetensors=True,
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print("Model loaded successfully!")
    return model, processor
29
+
30
 
31
def meralion_inference(prompt, uploaded_file):
    """Run instruction-following inference over an uploaded audio file.

    Args:
        prompt: Text instruction to apply to the audio (e.g. "transcribe").
        uploaded_file: Value produced by a ``gr.File`` input — a filepath
            string in Gradio 4.x, or a tempfile-like object exposing ``.name``
            in older Gradio versions. ``None`` when nothing was uploaded.

    Returns:
        str: The model's decoded response, or a human-readable error/help
        message (this function never raises — errors are reported as text
        so the UI stays responsive).
    """
    global model, processor

    if uploaded_file is None:
        return "Please upload an audio file."

    # Load model on first run (cached in module globals afterwards).
    model, processor = load_model()

    try:
        # BUG FIX: gradio==4.44.0 (pinned in requirements.txt) passes a plain
        # filepath string from gr.File; older versions passed a tempfile-like
        # object with a `.name` attribute. Accessing `.name` on a str raised
        # AttributeError on every upload. Support both forms.
        audio_path = (
            uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
        )

        # Load audio resampled to the 16 kHz rate the processor expects.
        audio_array, _sr = librosa.load(audio_path, sr=16000)

        # Prompt template: <SpeechHere> marks where audio features are spliced in.
        prompt_template = "Instruction: {query}\nFollow the text instruction based on the following audio: <SpeechHere>"
        conversation = [
            {"role": "user", "content": prompt_template.format(query=prompt)}
        ]
        chat_prompt = processor.tokenizer.apply_chat_template(
            conversation=conversation, tokenize=False, add_generation_prompt=True
        )

        # Tokenize the text and extract audio features in one call.
        inputs = processor(text=chat_prompt, audios=audio_array)

        # Move tensors to the inference device and downcast float32 features
        # to bfloat16 to match the model weights' dtype.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(device)
                if value.dtype == torch.float32:
                    inputs[key] = inputs[key].to(torch.bfloat16)

        # Generate. NOTE: do_sample=True means output is intentionally
        # nondeterministic between runs.
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, do_sample=True, temperature=0.7
            )
        # Strip the prompt tokens; decode only the newly generated ids.
        generated_ids = outputs[:, inputs["input_ids"].size(1):]
        response = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return response

    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the app.
        return f"Error during inference: {str(e)}"
75
 
76
 
77
# Build the Gradio UI. Statement order inside the Blocks context defines the
# on-screen layout: a header, then one row with prompt / file / output widgets,
# then the submit button below the row.
with gr.Blocks() as demo:
    gr.Markdown("# MERaLiON-2-10B Audio Demo")
    with gr.Row():
        # Instruction text; pre-filled with a transcription prompt as a default.
        prompt_input = gr.Textbox(
            label="Enter Prompt", value="Please transcribe this speech.", lines=2
        )
        # file_types restricts the picker client-side; the "max 300s" in the
        # label is advisory only — no server-side duration check is enforced here.
        file_input = gr.File(
            label="Upload Audio File (WAV/MP3, max 300s)",
            file_types=[".wav", ".mp3", ".m4a"],
        )
        output_text = gr.Textbox(label="Model Output", lines=8)

    submit_btn = gr.Button("Run Inference", variant="primary")
    # Wire the button to the inference function: (prompt, file) -> response text.
    submit_btn.click(
        meralion_inference, inputs=[prompt_input, file_input], outputs=output_text
    )

# Blocking call: starts the web server and serves the app.
demo.launch()
requirements.txt CHANGED
@@ -1,13 +1,9 @@
1
- # Pinned non-torch packages (change versions if you need newer/stable ones)
2
- gradio==3.50.1
3
- transformers==4.35.2
4
  librosa==0.10.0
5
- safetensors==0.3.2
6
- accelerate==0.20.3
7
  soundfile==0.12.1
8
-
9
- # NOTE: `torch` should be installed via the official PyTorch wheels that match
10
- # your CUDA version (or CPU-only). See the README.md for Windows CPU/CUDA
11
- # install commands and pick the appropriate wheel. To keep this file simple
12
- # we do not pin `torch` here.
13
  torch
 
 
1
+ # Core requirements for MERaLiON-2-10B
2
+ transformers==4.50.1
3
+ gradio==4.44.0
4
  librosa==0.10.0
5
+ safetensors==0.4.5
6
+ accelerate==0.41.0
7
  soundfile==0.12.1
 
 
 
 
 
8
  torch
9
+ flash-attn