xushengyuan commited on
Commit
fc3ecd6
·
1 Parent(s): 5ac3586

Move automatic device selection into the handler & add an optimized SWA implementation for client-side inference

Browse files
Files changed (3) hide show
  1. acestep/handler.py +8 -1
  2. acestep/optimized_swa.py +115 -0
  3. test.py +1 -11
acestep/handler.py CHANGED
@@ -146,7 +146,14 @@ class AceStepHandler:
146
  """
147
  try:
148
  if device == "auto":
149
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
150
 
151
  status_msg = ""
152
 
 
146
  """
147
  try:
148
  if device == "auto":
149
+ if hasattr(torch, 'xpu') and torch.xpu.is_available():
150
+ device = "xpu"
151
+ elif torch.cuda.is_available():
152
+ device = "cuda"
153
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
154
+ device = "mps"
155
+ else:
156
+ device = "cpu"
157
 
158
  status_msg = ""
159
 
acestep/optimized_swa.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import math
from typing import Optional

import torch
import torch.nn.functional as F


def optimized_sliding_window_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    window_size: int,
    scaling: Optional[float] = None,
) -> torch.Tensor:
    """Block-wise (chunked) sliding-window attention in eager PyTorch.

    Each query position attends to keys within ``+/- window_size`` of its own
    position (a symmetric window of width ``2 * window_size + 1``). The
    sequence is processed in chunks of ``window_size`` with a one-chunk halo
    on each side, so peak score memory is O(L * W) rather than the O(L^2) of
    dense attention.

    Args:
        query: [Batch, Heads, Seq_Len, Head_Dim]
        key:   [Batch, Heads, Seq_Len, Head_Dim]
        value: [Batch, Heads, Seq_Len, Head_Dim]
        window_size: one-sided window radius (must be >= 1)
        scaling: scale applied to attention scores;
            defaults to ``1 / sqrt(Head_Dim)``

    Returns:
        Attention output of shape [Batch, Heads, Seq_Len, Head_Dim].
    """
    b, h, seq_len, head_dim = query.shape

    if scaling is None:
        scaling = 1.0 / math.sqrt(head_dim)

    # 1. Right-pad Q/K/V so the sequence length is a multiple of the chunk
    #    size. Padded key positions are masked out below; padded query rows
    #    are cropped from the output at the end.
    pad_len = (window_size - (seq_len % window_size)) % window_size
    if pad_len > 0:
        query = F.pad(query, (0, 0, 0, pad_len))
        key = F.pad(key, (0, 0, 0, pad_len))
        value = F.pad(value, (0, 0, 0, pad_len))

    l_padded = query.shape[2]
    num_chunks = l_padded // window_size

    # 2. Halo-pad K/V by one window on each side: chunk i needs the key span
    #    [i*W - W, (i+1)*W + W) in global coordinates.
    key_padded = F.pad(key, (0, 0, window_size, window_size))
    value_padded = F.pad(value, (0, 0, window_size, window_size))

    # 3. Chunk the queries: [B, H, Num_Chunks, W, D].
    query_chunks = query.view(b, h, num_chunks, window_size, head_dim)

    # 4. Unfold K/V into overlapping 3W-wide windows with stride W.
    #    unfold yields [B, H, Num_Chunks, D, 3W]; transpose to [..., 3W, D].
    key_chunks = key_padded.unfold(2, 3 * window_size, window_size).transpose(-1, -2)
    value_chunks = value_padded.unfold(2, 3 * window_size, window_size).transpose(-1, -2)

    # 5. Scores: [..., W, D] @ [..., D, 3W] -> [..., W, 3W].
    scores = torch.matmul(query_chunks, key_chunks.transpose(-1, -2)) * scaling

    # 6. Geometric mask. Local query index q may see local key index k when
    #    k in [q, q + 2W]; in global coordinates that is exactly |k - q| <= W.
    local_q = torch.arange(window_size, device=query.device).unsqueeze(1)      # [W, 1]
    local_k = torch.arange(3 * window_size, device=query.device).unsqueeze(0)  # [1, 3W]
    mask = (local_k >= local_q) & (local_k <= local_q + 2 * window_size)
    mask = mask.view(1, 1, 1, window_size, 3 * window_size)

    # Padding mask: in halo-padded coordinates only indices
    # [W, W + seq_len) hold real keys; everything else (halo + alignment
    # padding) must not be attended to.
    valid_key = torch.zeros(
        l_padded + 2 * window_size, device=query.device, dtype=torch.bool
    )
    valid_key[window_size : window_size + seq_len] = True
    # Unfold to match key_chunks, then broadcast: [1, 1, Num_Chunks, 1, 3W].
    valid_key_chunks = valid_key.unfold(0, 3 * window_size, window_size)
    valid_key_chunks = valid_key_chunks.view(1, 1, num_chunks, 1, 3 * window_size)

    mask = mask & valid_key_chunks

    # Use dtype-min instead of -inf so rows that are almost fully masked
    # (padded query rows) still softmax without producing NaNs.
    scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)

    # 7. Softmax over each local window, then the weighted sum of values:
    #    [..., W, 3W] @ [..., 3W, D] -> [..., W, D].
    attn_probs = F.softmax(scores, dim=-1)
    output_chunks = torch.matmul(attn_probs, value_chunks)

    # 8. Merge chunks back to a flat sequence and drop alignment padding.
    #    reshape (not view) tolerates any non-contiguous matmul output.
    output = output_chunks.reshape(b, h, l_padded, head_dim)
    if pad_len > 0:
        output = output[:, :, :seq_len, :]

    return output
test.py CHANGED
@@ -35,23 +35,13 @@ def main():
35
  print(f"Using model: {model_name}")
36
 
37
  # Initialize service
38
- if hasattr(torch, 'xpu') and torch.xpu.is_available():
39
- device = "xpu"
40
- elif torch.cuda.is_available():
41
- device = "cuda"
42
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
43
- device = "mps"
44
- else:
45
- device = "cpu"
46
- print(f"Using device: {device}")
47
 
48
  use_llm = False
49
 
50
  status, enabled = handler.initialize_service(
51
  project_root=project_root,
52
  config_path=model_name,
53
- device=device,
54
- init_llm=use_llm,
55
  use_flash_attention=True, # Default in UI
56
  compile_model=True,
57
  offload_to_cpu=True,
 
35
  print(f"Using model: {model_name}")
36
 
37
  # Initialize service
 
 
 
 
 
 
 
 
 
38
 
39
  use_llm = False
40
 
41
  status, enabled = handler.initialize_service(
42
  project_root=project_root,
43
  config_path=model_name,
44
+ device='auto',
 
45
  use_flash_attention=True, # Default in UI
46
  compile_model=True,
47
  offload_to_cpu=True,