aTrapDeer committed (verified)
Commit 231df5f · Parent(s): 3ef959e

Bootstrap Qwen2-Audio custom endpoint repo

Files changed (3):
  1. README.md +62 -0
  2. handler.py +110 -0
  3. requirements.txt +6 -0
README.md ADDED
@@ -0,0 +1,62 @@
+ # Qwen2-Audio Caption Endpoint Template
+
+ Use this as a custom `handler.py` runtime for a Hugging Face Dedicated Endpoint.
+
+ ## Request contract
+
+ ```json
+ {
+   "inputs": {
+     "prompt": "Analyze and describe this music segment.",
+     "audio_base64": "<base64-encoded WAV bytes>",
+     "sample_rate": 16000,
+     "max_new_tokens": 384,
+     "temperature": 0.1
+   }
+ }
+ ```
+
+ ## Response contract
+
+ ```json
+ {
+   "generated_text": "..."
+ }
+ ```
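For reference, a minimal Python client for this request/response contract might look like the following sketch; the endpoint URL, token, and audio path are placeholders, and the `Authorization: Bearer` header is the standard auth scheme for Dedicated Endpoints.

```python
import base64

import requests

# Placeholders: substitute your real endpoint URL, HF token, and audio file.
ENDPOINT_URL = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
HF_TOKEN = "hf_xxx"

# Base64-encode the WAV bytes exactly as the request contract expects.
with open("path/to/song.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": {
        "prompt": "Analyze and describe this music segment.",
        "audio_base64": audio_b64,
        "sample_rate": 16000,
        "max_new_tokens": 384,
        "temperature": 0.1,
    }
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```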
+
+ ## Setup
+
+ Fastest way from this repo:
+
+ ```bash
+ python scripts/hf_clone.py qwen-endpoint --repo-id YOUR_USERNAME/YOUR_QWEN_ENDPOINT_REPO
+ ```
+
+ Then deploy a Dedicated Endpoint from that repo with task `custom`.
+
+ Manual path:
+
+ 1. Create a new model repo for your endpoint runtime.
+ 2. Copy `handler.py` from this folder into that repo as top-level `handler.py`.
+ 3. Add a `requirements.txt` containing at least:
+    - `torch`
+    - `torchaudio`
+    - `transformers>=4.53.0,<4.58.0`
+    - `soundfile`
+    - `numpy`
+ 4. Deploy a Dedicated Endpoint from that repo.
+ 5. Optional endpoint env var:
+    - `QWEN_MODEL_ID=Qwen/Qwen2-Audio-7B-Instruct`
+
+ Then point the `hf_endpoint` backend in `qwen_caption_app.py` at that endpoint URL.
+
+ ## Quick local test script
+
+ From this repo:
+
+ ```bash
+ python scripts/endpoint/test_qwen_caption_endpoint.py \
+   --url https://YOUR_ENDPOINT.endpoints.huggingface.cloud \
+   --token hf_xxx \
+   --audio path/to/song.wav
+ ```
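Note that Qwen2-Audio's feature extractor is Whisper-based and expects 16 kHz input, which is why the contract above uses `sample_rate: 16000`. If your source audio is at another rate, one option is to resample client-side before encoding; a rough sketch using torchaudio (already listed in `requirements.txt`), with `wav_to_base64_16k` as a hypothetical helper name:

```python
import base64
import io

import torchaudio

# Hypothetical helper (not part of this repo): load any audio file,
# downmix to mono, resample to 16 kHz, and return the base64 WAV string
# the request contract expects.
def wav_to_base64_16k(path: str) -> str:
    waveform, sr = torchaudio.load(path)           # (channels, samples)
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    buf = io.BytesIO()
    # Saving to a file-like object requires an explicit format.
    torchaudio.save(buf, waveform, 16000, format="wav")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
```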
handler.py ADDED
@@ -0,0 +1,110 @@
+ import base64
+ import io
+ import os
+ from typing import Any, Dict
+
+ import numpy as np
+ import soundfile as sf
+ import torch
+ from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
+
+
+ def _decode_audio_b64(audio_b64: str):
+     # Decode base64 WAV bytes and downmix to mono float32.
+     raw = base64.b64decode(audio_b64)
+     audio, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=True)
+     mono = audio.mean(axis=1).astype(np.float32)
+     return mono, int(sr)
+
+
+ class EndpointHandler:
+     """
+     HF Dedicated Endpoint custom handler contract:
+     request:
+         {
+             "inputs": {
+                 "prompt": "...",
+                 "audio_base64": "...",
+                 "sample_rate": 16000,
+                 "max_new_tokens": 384,
+                 "temperature": 0.1
+             }
+         }
+     response:
+         {"generated_text": "..."}
+     """
+
+     def __init__(self, model_dir: str = ""):
+         model_id = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2-Audio-7B-Instruct")
+         if model_dir and os.path.isdir(model_dir):
+             # Allows loading from files packaged in the endpoint model repo.
+             model_id = model_dir
+
+         dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+         self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+         self.model = Qwen2AudioForConditionalGeneration.from_pretrained(
+             model_id,
+             torch_dtype=dtype,
+             trust_remote_code=True,
+         )
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         # Accept both {"inputs": {...}} and a bare payload dict.
+         payload = data.get("inputs", data) if isinstance(data, dict) else {}
+         prompt = str(payload.get("prompt", "Analyze this music audio.")).strip()
+         audio_b64 = payload.get("audio_base64")
+         if not audio_b64:
+             return {"error": "audio_base64 is required"}
+
+         max_new_tokens = int(payload.get("max_new_tokens", 384))
+         temperature = float(payload.get("temperature", 0.1))
+
+         audio, sr = _decode_audio_b64(audio_b64)
+         sampling_rate = int(payload.get("sample_rate", sr))
+
+         conversation = [
+             {"role": "system", "content": "You are a precise music analysis assistant."},
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "audio", "audio_url": "local://audio.wav"},
+                     {"type": "text", "text": prompt},
+                 ],
+             },
+         ]
+         chat_text = self.processor.apply_chat_template(
+             conversation,
+             add_generation_prompt=True,
+             tokenize=False,
+         )
+         inputs = self.processor(
+             text=chat_text,
+             audios=[audio],
+             sampling_rate=sampling_rate,
+             return_tensors="pt",
+             padding=True,
+         )
+
+         device = next(self.model.parameters()).device
+         for key, value in list(inputs.items()):
+             if hasattr(value, "to"):
+                 inputs[key] = value.to(device)
+
+         do_sample = bool(temperature and temperature > 0)
+         gen_kwargs = {
+             "max_new_tokens": int(max_new_tokens),
+             "do_sample": do_sample,
+         }
+         if do_sample:
+             gen_kwargs["temperature"] = max(float(temperature), 1e-5)
+
+         with torch.no_grad():
+             generated_ids = self.model.generate(**inputs, **gen_kwargs)
+         # Trim the prompt tokens so only newly generated text is decoded.
+         prompt_tokens = inputs["input_ids"].shape[1]
+         generated_ids = generated_ids[:, prompt_tokens:]
+         text = self.processor.batch_decode(
+             generated_ids,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False,
+         )[0]
+         return {"generated_text": text.strip()}
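For a quick local sanity check of the handler outside an endpoint, a synthetic one-second tone exercises the full request path. This is illustrative only: instantiating `EndpointHandler()` downloads the 7B model weights, so it assumes a machine with a suitable GPU (or patience on CPU).

```python
import base64
import io

import numpy as np
import soundfile as sf

from handler import EndpointHandler

# One second of a 440 Hz sine at 16 kHz, written as WAV bytes in memory.
sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False)
tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
buf = io.BytesIO()
sf.write(buf, tone, sr, format="WAV")
audio_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

handler = EndpointHandler()
result = handler({
    "inputs": {
        "prompt": "Describe this audio.",
        "audio_base64": audio_b64,
        "sample_rate": sr,
        "max_new_tokens": 64,
    }
})
print(result)
```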
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ torchaudio
+ soundfile
+ numpy
+ transformers>=4.53.0,<4.58.0
+ accelerate