OzTianlu commited on
Commit
ec2aaa8
·
verified ·
1 Parent(s): e03da02

Upload 4 files

Browse files
Files changed (4) hide show
  1. AsteriskForCausalLM.py +473 -0
  2. README.md +467 -3
  3. handler.py +126 -0
  4. requirements.txt +3 -0
AsteriskForCausalLM.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hybrid ASPP-Attention Architecture (Asterisk Model)
3
+ Combines Adjacency-Structured Parallel Propagation (ASPP) with standard attention mechanisms
4
+ to enhance model expressiveness while maintaining efficiency.
5
+
6
+ Architecture Design:
7
+ - Hybrid layers: Standard attention + ASPP operator in parallel
8
+ - Gate mechanism for dynamic fusion
9
+ - Knowledge distillation from SmolLM2-135M base model
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
16
+ from transformers.models.llama.modeling_llama import (
17
+ LlamaAttention,
18
+ LlamaDecoderLayer,
19
+ LlamaRMSNorm,
20
+ LlamaMLP,
21
+ )
22
+ from transformers import AutoConfig, AutoModelForCausalLM
23
+ from typing import Optional, Tuple, List
24
+
25
+
26
class AsteriskConfig(LlamaConfig):
    """Configuration for the Asterisk hybrid ASPP-Attention model.

    Extends ``LlamaConfig`` with the ASPP and pi-flow hyperparameters used
    by the hybrid decoder layers. Registered under ``model_type="asterisk"``.
    """

    model_type = "asterisk"

    def __init__(
        self,
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,  # Fixed at 1 for Union-Find (only parent)
        # π-flow parameters
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # ASPP branch settings
        self.hybrid_layer_indices = hybrid_layer_indices
        self.aspp_hidden_dim = aspp_hidden_dim
        self.aspp_num_steps = aspp_num_steps
        self.aspp_dropout = aspp_dropout
        self.aspp_num_neighbors = aspp_num_neighbors
        # π-flow refinement settings
        self.pi_flow = pi_flow
        self.pi_flow_steps = pi_flow_steps
        self.pi_flow_scale = pi_flow_scale
        self.pi_flow_use_gate = pi_flow_use_gate
58
+
59
+
60
class ASPPOperator(nn.Module):
    """
    Asterisk Operator (ASPP) - Union-Find Graph Propagation.

    Each position keeps a single parent pointer forming a linear chain
    (``parent[i] = max(0, i - 1)``, position 0 is its own root) and
    repeatedly fuses its own features with its parent's features:

        h_i^(t+1) = phi(h_i^(t), h_parent[i])

    The "gather" is plain integer indexing, so each propagation step is
    O(n) in sequence length. ``compute_parent_indices`` is the hook point
    for dynamic union operations (semantic similarity, learned grouping).

    Args:
        hidden_size: Dimension of the input/output hidden states.
        aspp_hidden_dim: Internal ASPP dimension (None = use hidden_size).
            When it differs from ``hidden_size``, a down/up projection
            bottleneck is added around the propagation steps.
        num_steps: Maximum number of evolution steps K (default: 2).
        dropout: Dropout rate for regularization (default: 0.1).
        num_neighbors: Ignored; fixed at 1 (parent only) for Union-Find.
    """

    def __init__(self, hidden_size: int, aspp_hidden_dim: Optional[int] = None, num_steps: int = 2, dropout: float = 0.1, num_neighbors: int = 1):
        super().__init__()
        self.hidden_size = hidden_size
        self.aspp_hidden_dim = aspp_hidden_dim or hidden_size
        self.num_steps = num_steps
        self.num_neighbors = 1  # Fixed: only parent

        # Optional bottleneck projection when a smaller internal dim is requested.
        self.use_projection = (self.aspp_hidden_dim != hidden_size)
        if self.use_projection:
            self.down_proj = nn.Linear(hidden_size, self.aspp_hidden_dim)
            self.up_proj = nn.Linear(self.aspp_hidden_dim, hidden_size)
            self.proj_dropout = nn.Dropout(dropout)

        # Message aggregation: maps concat(self, parent) -> update.
        self.message_net = nn.Sequential(
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim * 2),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(self.aspp_hidden_dim * 2, self.aspp_hidden_dim),
            nn.Dropout(dropout),
        )

        # Learnable effective step count: k = sigmoid(k_logit) * num_steps.
        self.k_logit = nn.Parameter(torch.tensor(1.0))

        # Learnable residual scale keeps early updates small for stability.
        self.residual_scale = nn.Parameter(torch.tensor(0.1))

        # Normalization applied after every propagation step.
        self.norm = nn.LayerNorm(self.aspp_hidden_dim, eps=1e-5)

    def compute_parent_indices(self, seq_len: int, device) -> torch.Tensor:
        """
        Return the parent index of every position.

        Linear-chain initialization: ``parent[i] = max(0, i - 1)``; position
        0 points to itself (root). Can be extended with dynamic union
        operations (semantic similarity, positional heuristics, learned
        grouping).

        Args:
            seq_len: Sequence length (may be 0; returns an empty tensor).
            device: Device to allocate the index tensor on.

        Returns:
            [seq_len] LongTensor of parent indices.
        """
        parent_indices = torch.arange(seq_len, device=device) - 1
        # Guard: an empty sequence has no root position to fix up
        # (the original indexed parent_indices[0] unconditionally and
        # raised IndexError for seq_len == 0).
        if seq_len > 0:
            parent_indices[0] = 0  # Root points to itself
            parent_indices = torch.clamp(parent_indices, 0, seq_len - 1)
        return parent_indices

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Run K steps of parent-chain message passing.

        Args:
            hidden_states: [batch_size, seq_len, hidden_size]
        Returns:
            evolved_states: [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Project to the internal dimension if a bottleneck is configured.
        if self.use_projection:
            h_t = self.proj_dropout(self.down_proj(hidden_states))
        else:
            h_t = hidden_states

        # Learnable number of steps, clamped to [1, num_steps].
        k_steps = max(1, int(torch.sigmoid(self.k_logit) * self.num_steps))

        # The parent structure is static across steps, so compute it once
        # (hoisted out of the loop; previously recomputed every iteration).
        parent_indices = self.compute_parent_indices(seq_len, h_t.device)  # [L]

        for _ in range(k_steps):
            # Gather parent features via plain indexing: [B, L, D].
            parent_features = h_t[:, parent_indices, :]

            # Message passing: combine self + parent features.
            message_input = torch.cat([h_t, parent_features], dim=-1)  # [B, L, 2D]
            update = self.message_net(message_input)  # [B, L, D]

            # Scaled residual connection + norm for stability.
            h_t = self.norm(h_t + self.residual_scale * update)

        # Project back to the original dimension if needed.
        if self.use_projection:
            h_t = self.proj_dropout(self.up_proj(h_t))

        return h_t
184
+
185
+
186
class HybridASPPAttentionLayer(LlamaDecoderLayer):
    """
    Hybrid decoder layer combining the ASPP operator and standard attention.
    Inherits from LlamaDecoderLayer to stay compatible with the Llama stack.

    Pipeline:
      1. Parallel branches over the normed input:
         - ASPP operator (local, Union-Find structured reasoning)
         - Standard LlamaAttention (global context)
      2. Gated fusion of both branch outputs.
      3. π-flow refinement (optional, per-layer Euler steps).
      4. Feed-forward network.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 1):
        # Initialize parent LlamaDecoderLayer (self_attn, mlp, layernorms).
        super().__init__(config, layer_idx)

        # ASPP branch.
        self.aspp_operator = ASPPOperator(
            hidden_size=config.hidden_size,
            aspp_hidden_dim=aspp_hidden_dim,
            num_steps=aspp_num_steps,
            dropout=aspp_dropout,
            num_neighbors=aspp_num_neighbors,
        )

        # Gated fusion mechanism with dropout.
        self.fusion_gate = nn.Sequential(
            nn.Linear(config.hidden_size * 2, config.hidden_size),
            nn.Dropout(aspp_dropout),
            nn.Sigmoid(),
        )

        # Start balanced: sigmoid(0) = 0.5 weighs both branches equally.
        with torch.no_grad():
            self.fusion_gate[0].bias.fill_(0.0)

        # π-flow: per-layer refinement ASPP (optional).
        if getattr(config, 'pi_flow', False):
            self.pi_flow_aspp = ASPPOperator(
                hidden_size=config.hidden_size,
                aspp_hidden_dim=aspp_hidden_dim,
                num_steps=aspp_num_steps,
                dropout=aspp_dropout,
                num_neighbors=aspp_num_neighbors,
            )

            # BUGFIX: cache the step count at init time. LlamaDecoderLayer
            # does not keep a `config` attribute and callers never pass one
            # in kwargs, so the previous runtime lookup
            # (`getattr(self.config ... kwargs.get('config'), ...)`)
            # silently fell back to 1 step regardless of config.pi_flow_steps.
            self.pi_flow_steps = getattr(config, 'pi_flow_steps', 1)

            # Learnable flow scale (per-layer).
            self.pi_flow_scale = nn.Parameter(
                torch.tensor(getattr(config, 'pi_flow_scale', 0.2))
            )

            # Token-wise adaptive gating (optional).
            if getattr(config, 'pi_flow_use_gate', True):
                self.pi_flow_gate = nn.Sequential(
                    nn.Linear(config.hidden_size, config.hidden_size // 4),
                    nn.SiLU(),
                    nn.Dropout(aspp_dropout),
                    nn.Linear(config.hidden_size // 4, 1),
                    nn.Sigmoid(),
                )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Override LlamaDecoderLayer.forward to add the ASPP branch and π-flow.

        Returns a single hidden-states tensor, like LlamaDecoderLayer.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # ASPP branch.
        aspp_output = self.aspp_operator(hidden_states)

        # Attention branch - parent's self_attn returns a tuple; the cache
        # element is discarded here.
        attn_output, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )

        # Gated fusion: gate * ASPP + (1 - gate) * Attention.
        fusion_input = torch.cat([aspp_output, attn_output], dim=-1)
        gate = self.fusion_gate(fusion_input)
        fused_output = gate * aspp_output + (1 - gate) * attn_output

        # Residual connection.
        hidden_states = residual + fused_output

        # π-flow: multi-step refinement h <- h + α * v(h) (per-layer).
        if hasattr(self, 'pi_flow_aspp'):
            # getattr fallback keeps old checkpoints (without the cached
            # attribute) loadable with the previous 1-step behavior.
            for _ in range(getattr(self, 'pi_flow_steps', 1)):
                # Velocity field v(h) from the refinement ASPP.
                v = self.pi_flow_aspp(hidden_states)

                # Adaptive per-token flow strength, if gating is enabled.
                if hasattr(self, 'pi_flow_gate'):
                    alpha = self.pi_flow_scale * self.pi_flow_gate(hidden_states)  # [B, L, 1]
                else:
                    alpha = self.pi_flow_scale

                # Euler step.
                hidden_states = hidden_states + alpha * v

        # MLP block (parent's mlp).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Return only the hidden_states tensor, like LlamaDecoderLayer.
        return hidden_states
316
+
317
+
318
class AsteriskLlamaModel(LlamaModel):
    """
    Llama backbone whose decoder layers are replaced with hybrid
    ASPP+Attention layers.

    By default every layer becomes hybrid (maximum expressiveness); pass
    ``hybrid_layer_indices`` to restrict the replacement to a subset.
    """

    def __init__(self, config: LlamaConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: int = 2, aspp_dropout: float = 0.1, aspp_num_neighbors: int = 2):
        super().__init__(config)

        # Default: make ALL layers hybrid.
        if hybrid_layer_indices is None:
            hybrid_layer_indices = list(range(config.num_hidden_layers))

        self.hybrid_layer_indices = hybrid_layer_indices

        # Swap the selected layers for hybrid ones (with per-layer π-flow
        # if enabled in the config); indices beyond the layer count are
        # silently skipped.
        num_layers = len(self.layers)
        for idx in hybrid_layer_indices:
            if idx >= num_layers:
                continue
            self.layers[idx] = HybridASPPAttentionLayer(
                config,
                layer_idx=idx,
                aspp_hidden_dim=aspp_hidden_dim,
                aspp_num_steps=aspp_num_steps,
                aspp_dropout=aspp_dropout,
                aspp_num_neighbors=aspp_num_neighbors,
            )

        # Re-run weight init so the freshly added ASPP modules are initialized.
        self.post_init()
350
+
351
+
352
class AsteriskForCausalLM(LlamaForCausalLM):
    """
    Asterisk Causal LM: Llama LM head over the hybrid ASPP-Attention backbone.

    Registered as: AsteriskForCausalLM (model_type "asterisk").
    """

    config_class = AsteriskConfig

    def __init__(self, config: AsteriskConfig, hybrid_layer_indices: Optional[List[int]] = None, aspp_hidden_dim: Optional[int] = None, aspp_num_steps: Optional[int] = None, aspp_dropout: Optional[float] = None, aspp_num_neighbors: Optional[int] = None):
        # Resolve ASPP parameters: explicit arguments take precedence, the
        # config fills in anything left as None.
        # BUGFIX: previously aspp_num_steps / aspp_dropout / aspp_num_neighbors
        # from the config unconditionally overrode explicitly passed arguments
        # (inconsistent with the None-fallback used for hybrid_layer_indices
        # and aspp_hidden_dim). The defaults are now None sentinels so the
        # config only applies when the caller did not specify a value.
        if hybrid_layer_indices is None:
            hybrid_layer_indices = getattr(config, 'hybrid_layer_indices', None)
        if aspp_hidden_dim is None:
            aspp_hidden_dim = getattr(config, 'aspp_hidden_dim', None)
        if aspp_num_steps is None:
            aspp_num_steps = getattr(config, 'aspp_num_steps', 2)
        if aspp_dropout is None:
            aspp_dropout = getattr(config, 'aspp_dropout', 0.1)
        if aspp_num_neighbors is None:
            aspp_num_neighbors = getattr(config, 'aspp_num_neighbors', 2)

        super().__init__(config)

        # Replace the plain Llama backbone with the hybrid Asterisk backbone.
        self.model = AsteriskLlamaModel(config, hybrid_layer_indices, aspp_hidden_dim, aspp_num_steps, aspp_dropout, aspp_num_neighbors)

        # Store hybrid layer info in config for serialization (None means
        # "all layers" and round-trips correctly through AsteriskLlamaModel).
        self.config.hybrid_layer_indices = hybrid_layer_indices

        # Initialize weights.
        self.post_init()

    @classmethod
    def from_pretrained_base(
        cls,
        base_model_path: str,
        config: Optional[AsteriskConfig] = None,  # Accept pre-configured config
        hybrid_layer_indices: Optional[List[int]] = None,
        aspp_hidden_dim: Optional[int] = None,
        aspp_num_steps: int = 2,
        aspp_dropout: float = 0.1,
        aspp_num_neighbors: int = 1,  # Fixed at 1 for Union-Find (only parent)
        # π-flow parameters
        pi_flow: bool = False,
        pi_flow_steps: int = 1,
        pi_flow_scale: float = 0.2,
        pi_flow_use_gate: bool = True,
        **kwargs
    ):
        """
        Load a base Llama model and convert it to the Asterisk architecture.

        Args:
            base_model_path: Path to the base model checkpoint.
            config: Pre-configured AsteriskConfig (if provided, the other
                ASPP/π-flow parameters below are ignored).
            hybrid_layer_indices: Which layers to make hybrid (None for all).
            aspp_hidden_dim: Internal ASPP dimension (None = model hidden_size).
            aspp_num_steps: Number of ASPP evolution steps K (default: 2).
            aspp_dropout: Dropout rate for ASPP regularization (default: 0.1).
            aspp_num_neighbors: Fixed at 1 (only parent) for Union-Find.
            pi_flow: Enable π-flow refinement (default: False).
            pi_flow_steps: Number of flow refinement steps (default: 1).
            pi_flow_scale: Initial flow scale parameter (default: 0.2).
            pi_flow_use_gate: Use token-wise adaptive gating (default: True).

        Returns:
            Tuple of (asterisk_model, base_model); the base model is returned
            so callers can use it e.g. as a distillation teacher.
        """
        # Load the base model; its weights seed the non-hybrid parts.
        base_model = LlamaForCausalLM.from_pretrained(base_model_path, **kwargs)
        base_config = base_model.config

        if config is not None:
            # Use the pre-configured config as-is.
            asterisk_config = config
        else:
            # Derive an Asterisk config from the base config plus the
            # ASPP / π-flow parameters.
            asterisk_config = AsteriskConfig(
                **base_config.to_dict(),
                hybrid_layer_indices=hybrid_layer_indices,
                aspp_hidden_dim=aspp_hidden_dim,
                aspp_num_steps=aspp_num_steps,
                aspp_dropout=aspp_dropout,
                aspp_num_neighbors=aspp_num_neighbors,
                pi_flow=pi_flow,
                pi_flow_steps=pi_flow_steps,
                pi_flow_scale=pi_flow_scale,
                pi_flow_use_gate=pi_flow_use_gate,
            )

        # Build the Asterisk model (the config carries all ASPP params).
        asterisk_model = cls(asterisk_config)

        # Transfer weights from the base model; strict=False because the
        # ASPP / fusion / π-flow modules have no counterpart in the base.
        asterisk_model.load_state_dict(base_model.state_dict(), strict=False)

        print(f"✓ Converted base model to Asterisk architecture with Graph Propagation")
        print(f"  Hybrid layers: {asterisk_model.model.hybrid_layer_indices}")
        aspp_dim_str = f"{asterisk_config.aspp_hidden_dim}" if asterisk_config.aspp_hidden_dim else f"{base_config.hidden_size} (full)"
        print(f"  ASPP config: dim={aspp_dim_str}, steps={asterisk_config.aspp_num_steps}, dropout={asterisk_config.aspp_dropout}, neighbors={asterisk_config.aspp_num_neighbors}")
        if asterisk_config.pi_flow:
            print(f"  π-flow enabled: steps={asterisk_config.pi_flow_steps}, scale={asterisk_config.pi_flow_scale}, gate={asterisk_config.pi_flow_use_gate}")

        return asterisk_model, base_model
455
+
456
+
457
# Register with the Auto* factories so that
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)
# can resolve model_type "asterisk" to these classes.
AutoConfig.register("asterisk", AsteriskConfig)
AutoModelForCausalLM.register(AsteriskConfig, AsteriskForCausalLM)
460
+
461
+
462
def get_model_info(model):
    """Print model architecture information: parameter counts, fp32 size,
    and (for Asterisk models) the hybrid-layer configuration."""
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f" • Total parameters: {total_params:,}")
    print(f" • Trainable parameters: {trainable_params:,}")
    print(f" • Model size: {total_params * 4 / 1024**2:.2f} MB (fp32)")

    if isinstance(model, AsteriskForCausalLM):
        hybrid_indices = model.model.hybrid_layer_indices
        print(f" • Hybrid layer indices: {hybrid_indices}")
        print(f" • Number of hybrid layers: {len(hybrid_indices)}")
README.md CHANGED
@@ -1,3 +1,467 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - sr
4
+ - en
5
+ license: apache-2.0
6
+ tags:
7
+ - text-generation
8
+ - reasoning
9
+ - serbian
10
+ - asterisk
11
+ - aspp
12
+ - hybrid-architecture
13
+ - multilingual
14
+ datasets:
15
+ - ODA-Mixture-100k
16
+ - ultrachat_200k_serbian
17
+ metrics:
18
+ - accuracy
19
+ - perplexity
20
+ base_model: Geilim-1B-Instruct
21
+ model-index:
22
+ - name: Geilim-1B-SR-Instruct
23
+ results: []
24
+ ---
25
+
26
+ # Geilim-1B-SR-Instruct
27
+
28
+ <div align="center">
29
+ <h3>🇷🇸 Serbian Reasoning Model - AI Democratization Project</h3>
30
+ <p><em>Bringing advanced reasoning capabilities to Serbian language</em></p>
31
+ </div>
32
+
33
+ ## Model Description
34
+
35
+ **Geilim-1B-SR-Instruct** is a 1.3B parameter Serbian reasoning model that combines:
36
+ - **Base**: Geilim-1B-Instruct (1B parameters, Llama-3 architecture, 16 layers)
37
+ - **Architecture**: Asterisk hybrid ASPP + Attention
38
+ - **Training**: 50% ODA-Mixture-100k (reasoning) + 50% UltraChat Serbian (conversations)
39
+ - **Goal**: Democratize AI by bringing reasoning to underrepresented languages
40
+
41
+ ### Key Features
42
+
43
+ - ✅ **Hybrid Architecture**: All 16 layers use ASPP + standard Attention
44
+ - ✅ **Graph-based Reasoning**: Union-Find structure with 6-step iterative propagation
45
+ - ✅ **π-flow Refinement**: 4-step continuous flow dynamics for enhanced reasoning
46
+ - ✅ **Bilingual**: Serbian language with preserved English reasoning capabilities
47
+ - ✅ **Efficient**: ~1.3B total parameters, trainable on 2x consumer GPUs
48
+
49
+ ## Model Details
50
+
51
+ ### Model Architecture
52
+
53
+ ```
54
+ Input → Embedding
55
+
56
+ Layers 0-15: Hybrid ASPP + Attention (ALL 16 layers)
57
+ ├─ ASPP Branch (Union-Find graph reasoning)
58
+ │ ├─ 6-step iterative propagation
59
+ │ ├─ Hidden dim: 512 (reduced from 2048)
60
+ │ └─ π-flow: 4-step refinement
61
+ └─ Attention Branch (standard self-attention)
62
+
63
+ Gated Fusion: output = gate * ASPP(x) + (1-gate) * Attention(x)
64
+
65
+ Output → LM Head
66
+ ```
67
+
68
+ ### Technical Specifications
69
+
70
+ - **Parameters**: ~1.3B (1B base + 300M ASPP/π-flow)
71
+ - **Layers**: 16 (all hybrid)
72
+ - **Hidden Size**: 2048
73
+ - **Attention Heads**: 32
74
+ - **KV Heads**: 8 (GQA)
75
+ - **Vocabulary**: 128,256 tokens
76
+ - **Context Length**: 131,072 tokens (with RoPE scaling)
77
+ - **Precision**: bfloat16
78
+
79
+ ### ASPP Configuration
80
+
81
+ - **Hidden Dim**: 512 (dimensionality reduction)
82
+ - **Iteration Steps**: 6
83
+ - **Dropout**: 0.15
84
+ - **Graph Structure**: Union-Find (parent-only connections)
85
+
86
+ ### π-flow Configuration
87
+
88
+ - **Steps**: 4
89
+ - **Scale**: 0.4
90
+ - **Gating**: Adaptive per-token
91
+ - **Purpose**: Multi-step refinement in probability space
92
+
93
+ ## Intended Use
94
+
95
+ ### Primary Use Cases
96
+
97
+ 1. **Serbian Language Tasks**:
98
+ - Conversational AI in Serbian
99
+ - Question answering in Serbian
100
+ - Text generation and completion
101
+
102
+ 2. **Reasoning Tasks**:
103
+ - Mathematical problem solving
104
+ - Code generation and debugging
105
+ - Step-by-step logical reasoning
106
+
107
+ 3. **Bilingual Applications**:
108
+ - Serbian-English translation assistance
109
+ - Cross-lingual reasoning tasks
110
+
111
+ ### Out-of-Scope Use
112
+
113
+ - Production-critical applications without further testing
114
+ - Tasks requiring real-time factual accuracy (model may hallucinate)
115
+ - Languages other than Serbian and English (limited support)
116
+
117
+ ## How to Use
118
+
119
+ ### Installation
120
+
121
+ ```bash
122
+ pip install torch transformers accelerate
123
+ ```
124
+
125
+ ### Basic Usage
126
+
127
+ ```python
128
+ import torch
129
+ from transformers import AutoTokenizer, AutoModelForCausalLM
130
+
131
+ # Load model and tokenizer
132
+ model_name = "NoesisLab/Geilim-1B-SR-Instruct"
133
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
134
+ model = AutoModelForCausalLM.from_pretrained(
135
+ model_name,
136
+ trust_remote_code=True,
137
+ torch_dtype=torch.bfloat16,
138
+ device_map="auto",
139
+ )
140
+
141
+ # Serbian conversation
142
+ messages = [
143
+ {"role": "user", "content": "Kakvu ulogu igraju nagrade i pozitivno pojačanje u dresuri Bigla i kako se mogu efikasno koristiti bez podsticanja lošeg ponašanja?"}
144
+ ]
145
+
146
+ # Apply chat template
147
+ input_text = tokenizer.apply_chat_template(
148
+ messages,
149
+ tokenize=False,
150
+ add_generation_prompt=True
151
+ )
152
+
153
+ # Tokenize
154
+ inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
155
+
156
+ # Generate
157
+ outputs = model.generate(
158
+ **inputs,
159
+ max_new_tokens=200,
160
+ temperature=0.7,
161
+ top_p=0.9,
162
+ repetition_penalty=1.1,
163
+ do_sample=True,
164
+ )
165
+
166
+ # Decode
167
+ response = tokenizer.decode(
168
+ outputs[0][inputs['input_ids'].shape[1]:],
169
+ skip_special_tokens=True
170
+ )
171
+ print(response)
172
+ ```
173
+
174
+
175
+
176
+
177
+ ### Recommended Generation Parameters
178
+
179
+ ```python
180
+ generation_config = {
181
+ "max_new_tokens": 200,
182
+ "temperature": 0.7, # Balance creativity and coherence
183
+ "top_p": 0.9, # Nucleus sampling
184
+ "repetition_penalty": 1.1, # Reduce repetition
185
+ "do_sample": True,
186
+ }
187
+ ```
188
+
189
+
190
+ ## Training Data
191
+
192
+ ### Dataset Composition
193
+
194
+ The model was trained on a balanced mix of two datasets:
195
+
196
+ #### 1. ODA-Mixture-100k (50% - Reasoning Data)
197
+
198
+ **101,306 reasoning samples** across three domains:
199
+
200
+ - **Math** (50,244 samples): AM-Thinking-v1-Distilled-math
201
+ - Mathematical problem solving with step-by-step reasoning
202
+ - Format: instruction → response (reasoning trace) → final answer
203
+
204
+ - **Code** (50,245 samples): AM-Thinking-v1-Distilled-code
205
+ - Programming problems with detailed solutions
206
+ - Code generation, debugging, and explanation tasks
207
+
208
+ - **General** (817 samples): LIMO
209
+ - General reasoning tasks
210
+ - Logic puzzles, common sense reasoning
211
+
212
+ #### 2. UltraChat Serbian (50% - Language Data)
213
+
214
+ **207,588 high-quality Serbian conversations**:
215
+
216
+ - Translated from UltraChat 200k
217
+ - Multi-turn dialogues covering diverse topics
218
+ - Topics: science, culture, daily life, reasoning, education
219
+ - Format: `messages_srb` (Serbian), `messages_eng` (English reference)
220
+
221
+ ### Data Mixing Strategy
222
+
223
+ - **Balanced 50/50 split**: Preserve reasoning while learning Serbian
224
+ - **Automatic sampling**: Match smaller dataset size
225
+ - **Total samples**: ~100k (sampled from 202k available)
226
+ - **Train/Test split**: 95% / 5%
227
+
228
+ ## Training Procedure
229
+
230
+ ### Training Hyperparameters
231
+
232
+ - **Epochs**: 2
233
+ - **Batch Size**: 2 per device
234
+ - **Gradient Accumulation**: 8 steps (effective batch size = 16)
235
+ - **Learning Rate**: 5e-5
236
+ - **Warmup Ratio**: 0.1 (10% of training)
237
+ - **Weight Decay**: 0.05
238
+ - **Max Gradient Norm**: 1.0
239
+ - **Optimizer**: AdamW
240
+ - **Precision**: bfloat16 mixed precision
241
+ - **Gradient Checkpointing**: Enabled
242
+ - **Max Sequence Length**: 2048 tokens
243
+
244
+ ### Training Infrastructure
245
+
246
+ - **Framework**: HuggingFace Transformers + TRL SFTTrainer
247
+ - **Distributed Training**: Accelerate (multi-GPU)
248
+ - **GPUs**: 1x RTX PRO 6000
249
+ - **Training Time**: ~6-8 hours
250
+ - **Memory per GPU**: ~15GB
251
+
252
+
253
+ ## Evaluation
254
+
255
+ ### Qualitative Evaluation
256
+
257
+ The model demonstrates:
258
+ - ✅ Fluent Serbian language generation
259
+ - ✅ Step-by-step reasoning in Serbian
260
+ - ✅ Mathematical problem solving
261
+ - ✅ Code understanding and generation
262
+ - ✅ Multi-turn conversation capabilities
263
+
264
+
265
+
266
+ ## Limitations and Biases
267
+
268
+ ### Known Limitations
269
+
270
+ 1. **Language Coverage**: Primarily trained on Serbian and English; limited support for other languages
271
+ 2. **Factual Accuracy**: May generate plausible but incorrect information (hallucination)
272
+ 3. **Context Length**: While supporting 131k tokens, performance may degrade on very long contexts
273
+ 4. **Domain Specificity**: Best performance on conversational and reasoning tasks; may struggle with highly specialized domains
274
+ 5. **Training Data**: Limited to ~100k samples; may not cover all Serbian language variations
275
+
276
+ ### Potential Biases
277
+
278
+ - **Translation Bias**: Serbian data is translated from English, may not reflect natural Serbian expressions
279
+ - **Domain Bias**: Reasoning data focuses on math and code; may be less effective on other domains
280
+ - **Cultural Bias**: Training data may reflect Western cultural perspectives
281
+
282
+ ### Recommendations
283
+
284
+ - Verify factual claims with authoritative sources
285
+ - Test thoroughly before deployment in production
286
+ - Monitor for biased or inappropriate outputs
287
+ - Consider fine-tuning on domain-specific data for specialized applications
288
+
289
+ ## Ethical Considerations
290
+
291
+ ### AI Democratization
292
+
293
+ This model is part of an effort to democratize AI by bringing advanced capabilities to underrepresented languages. Serbian, despite having ~12 million speakers, has limited AI resources compared to high-resource languages.
294
+
295
+ ### Responsible Use
296
+
297
+ Users should:
298
+ - Be aware of potential biases and limitations
299
+ - Not use for malicious purposes (misinformation, harassment, etc.)
300
+ - Respect privacy and data protection regulations
301
+ - Consider societal impact of deployments
302
+
303
+ ### Environmental Impact
304
+
305
+ - **Training**: ~6-8 hours on 1x RTX PRO 6000 GPU (see Training Infrastructure)
306
+ - **Carbon Footprint**: Estimated ~5-10 kg CO2eq (depends on energy source)
307
+ - **Inference**: Efficient at 1.3B parameters, suitable for edge deployment
308
+
309
+
310
+ ## Technical Details
311
+
312
+ ### Asterisk Architecture
313
+
314
+ The model uses the **Asterisk** architecture, which combines:
315
+
316
+ 1. **ASPP (Adjacency-Structured Parallel Propagation)**:
317
+ - Graph-based reasoning with Union-Find structure
318
+ - Each token maintains parent pointer: `parent[i] = i-1`
319
+ - Iterative message passing: `h_i^(t+1) = φ(h_i^(t), h_parent[i])`
320
+ - 6 propagation steps per layer
321
+
322
+ 2. **π-flow Refinement**:
323
+ - Continuous flow dynamics: `h' = h + α * v(h)`
324
+ - Learnable velocity field for multi-step refinement
325
+ - Adaptive per-token gating
326
+ - 4 refinement steps per layer
327
+
328
+ 3. **Hybrid Fusion**:
329
+ - Parallel execution of ASPP and standard Attention
330
+ - Gated combination: `output = gate * ASPP(x) + (1-gate) * Attention(x)`
331
+ - Applied to all 16 layers
332
+
333
+ ### Model Configuration
334
+
335
+ ```json
336
+ {
337
+ "model_type": "asterisk",
338
+ "hidden_size": 2048,
339
+ "num_hidden_layers": 16,
340
+ "num_attention_heads": 32,
341
+ "num_key_value_heads": 8,
342
+ "intermediate_size": 8192,
343
+ "vocab_size": 128256,
344
+ "max_position_embeddings": 131072,
345
+
346
+ "aspp_hidden_dim": 512,
347
+ "aspp_num_steps": 6,
348
+ "aspp_dropout": 0.15,
349
+ "aspp_num_neighbors": 1,
350
+
351
+ "pi_flow": true,
352
+ "pi_flow_steps": 4,
353
+ "pi_flow_scale": 0.4,
354
+ "pi_flow_use_gate": true,
355
+
356
+ "hybrid_layer_indices": null
357
+ }
358
+ ```
359
+
360
+ ## Comparison with Other Models
361
+
362
+ | Model | Base | Params | Layers | Language | Reasoning | Architecture |
363
+ |-------|------|--------|--------|----------|-----------|--------------|
364
+ | SmolLM2-135M | - | 135M | 30 | English | ❌ | Transformer |
365
+ | Asterisk | SmolLM2 | 171M | 30 | English | ✅ ASPP | Hybrid |
366
+ | **Geilim-1B-SR** | Geilim-1B | 1.3B | 16 | Serbian | ✅ ASPP | Hybrid |
367
+
368
+ ### Advantages
369
+
370
+ - ✅ **Efficient Size**: 1.3B parameters, suitable for consumer hardware
371
+ - ✅ **Full Hybrid**: All 16 layers use ASPP + Attention
372
+ - ✅ **Bilingual**: Serbian + English capabilities
373
+ - ✅ **Reasoning**: Math, code, and general reasoning
374
+ - ✅ **Fast Training**: ~6-8 hours on a single RTX PRO 6000 (see Training Infrastructure)
375
+ - ✅ **Low Memory**: ~3GB inference, ~15-20GB training per GPU
376
+
377
+ ## Hardware Requirements
378
+
379
+ ### Inference
380
+
381
+ - **Minimum**: 1x GPU with 8GB VRAM (e.g., RTX 3060)
382
+ - **Recommended**: 1x GPU with 16GB+ VRAM (e.g., RTX 4080, A100)
383
+ - **CPU Only**: Possible but slow (~10-20x slower)
384
+
385
+ ### Training
386
+
387
+ - **Minimum**: 2x GPU with 24GB VRAM (e.g., RTX 3090/4090)
388
+ - **Recommended**: 2x GPU with 40GB VRAM (e.g., A100)
389
+ - **Memory**: ~20GB per GPU with gradient checkpointing
390
+
391
+ ## Model Card Authors
392
+
393
+ - **NoesisLab**
394
+
395
+ ## Citation
396
+
397
+ If you use this model in your research or applications, please cite:
398
+
399
+ ```bibtex
400
+ @software{geilim_1b_sr_2026,
401
+ title={Geilim-1B-SR-Instruct: Serbian Reasoning Model with Asterisk Architecture},
402
+ author={NoesisLab},
403
+ year={2026},
404
+ url={https://huggingface.co/NoesisLab/Geilim-1B-SR-Instruct},
405
+ note={AI Democratization - Bringing reasoning to underrepresented languages}
406
+ }
407
+ ```
408
+
409
+ ### Related Papers
410
+
411
+ ```bibtex
412
+ @article{asterisk_2026,
413
+ title={Asterisk: Hybrid ASPP-Attention Architecture for Efficient Reasoning},
414
+ author={NoesisLab},
415
+ year={2026},
416
+ note={Graph-based reasoning with Union-Find propagation}
417
+ }
418
+ ```
419
+
420
+ ## Acknowledgments
421
+
422
+ - **Geilim-1B-Instruct**: Base model (Llama-3 architecture, 1B parameters)
423
+ - **ODA-Mixture-100k**: Reasoning dataset (Math, Code, General)
424
+ - **UltraChat**: High-quality conversation dataset
425
+ - **Serbian NLP Community**: Language support and feedback
426
+ - **HuggingFace**: Transformers library and model hosting
427
+ - **Accelerate**: Distributed training framework
428
+
429
+ ## License
430
+
431
+ This model is released under the **Apache 2.0 License**, same as the base model.
432
+
433
+ ```
434
+ Copyright 2026 Asterisk Project
435
+
436
+ Licensed under the Apache License, Version 2.0 (the "License");
437
+ you may not use this file except in compliance with the License.
438
+ You may obtain a copy of the License at
439
+
440
+ http://www.apache.org/licenses/LICENSE-2.0
441
+
442
+ Unless required by applicable law or agreed to in writing, software
443
+ distributed under the License is distributed on an "AS IS" BASIS,
444
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
445
+ See the License for the specific language governing permissions and
446
+ limitations under the License.
447
+ ```
448
+
449
+
450
+ ## Version History
451
+
452
+ - **v1.0** (2026-02): Initial release
453
+ - 1.3B parameters (1B base + 300M ASPP/π-flow)
454
+ - Trained on 100k samples (50% ODA-Mixture + 50% UltraChat Serbian)
455
+ - All 16 layers use hybrid ASPP + Attention
456
+ - Supports Serbian and English
457
+
458
+ ## Contact and Support
459
+
460
+
461
+ - **Email**: lizx93@mail2.sysu.edu.cn
462
+
463
+ ---
464
+
465
+ <div align="center">
466
+ <h3>🇷🇸 Democratizing AI, one language at a time!</h3>
467
+ <p><em>Making advanced AI technology accessible to every language</em></p>
handler.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Dict, List, Union
5
+
6
+ import torch
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+
10
+ Json = Dict[str, Any]
11
+ Messages = List[Dict[str, str]] # [{"role":"user|assistant|system", "content":"..."}]
12
+
13
+
14
+ def _is_messages(x: Any) -> bool:
15
+ return (
16
+ isinstance(x, list)
17
+ and len(x) > 0
18
+ and all(isinstance(m, dict) and "role" in m and "content" in m for m in x)
19
+ )
20
+
21
+
22
class EndpointHandler:
    """
    Hugging Face Inference Endpoints custom handler.

    Expects the request body to be a dict that:
      - always contains ``inputs`` — a plain string prompt, a chat messages
        list ``[{"role": ..., "content": ...}]``, a ``{"messages": [...]}``
        dict, or a list of any of those (batch);
      - may contain ``parameters`` with generation overrides
        (``max_new_tokens``, ``temperature``, ``top_p``, ``top_k``,
        ``repetition_penalty``, ``do_sample``, ``num_beams``).

    Returns ``{"generated_text": ...}`` per input (a list of such dicts for a
    batch), containing only the newly generated tokens.
    """

    def __init__(self, model_dir: str):
        self.model_dir = model_dir

        # Pick device/dtype once at startup.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cuda":
            # bfloat16 is usually safe on A100/H100; if your instance doesn't
            # support bf16, change this to torch.float16.
            self.dtype = torch.bfloat16
        else:
            self.dtype = torch.float32

        # IMPORTANT: trust_remote_code=True because the repo contains
        # AsteriskForCausalLM.py and registers it via auto_map in the config.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=True,
        )

        # Make sure a pad token exists (the config uses pad_token_id=2, which
        # equals eos_token_id in many llama-like models).
        if self.tokenizer.pad_token_id is None and self.tokenizer.eos_token_id is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=self.dtype,
            device_map="auto" if self.device == "cuda" else None,
        )

        # device_map="auto" already places the model on GPU; only move
        # explicitly on the CPU path.
        if self.device != "cuda":
            self.model.to(self.device)

        self.model.eval()

    @torch.inference_mode()
    def __call__(self, data: Json) -> Union[Json, List[Json]]:
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}

        # Generation defaults (each can be overridden via `parameters`).
        max_new_tokens = int(params.get("max_new_tokens", 256))
        temperature = float(params.get("temperature", 0.7))
        top_p = float(params.get("top_p", 0.95))
        top_k = int(params.get("top_k", 0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))

        # Sampling is on by default whenever temperature > 0, unless the
        # caller says otherwise.
        do_sample = bool(params.get("do_sample", temperature > 0))
        num_beams = int(params.get("num_beams", 1))

        def _one(item: Any) -> Json:
            # Accept:
            #   1) a plain string prompt
            #   2) a messages list: [{"role": "user", "content": "..."}]
            #   3) a dict {"messages": [...]} (common chat style)
            if isinstance(item, dict) and "messages" in item:
                item = item["messages"]

            if _is_messages(item):
                # Chat-template path; tokenizer.apply_chat_template uses the
                # template shipped in the repo if one is configured.
                input_ids = self.tokenizer.apply_chat_template(
                    item,
                    return_tensors="pt",
                    add_generation_prompt=True,
                )
                # Single unpadded sequence: every position is a real token.
                attention_mask = torch.ones_like(input_ids)
            else:
                if not isinstance(item, str):
                    item = str(item)
                enc = self.tokenizer(item, return_tensors="pt")
                input_ids = enc["input_ids"]
                # FIX: forward the tokenizer's attention mask instead of
                # discarding it. Without an explicit mask, generate() infers
                # one from pad_token_id — ambiguous here because the pad
                # token is aliased to EOS (see __init__), which can corrupt
                # the output.
                attention_mask = enc.get("attention_mask")
                if attention_mask is None:
                    attention_mask = torch.ones_like(input_ids)

            input_ids = input_ids.to(self.model.device)
            attention_mask = attention_mask.to(self.model.device)
            input_len = input_ids.shape[-1]

            gen_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                # Sampling knobs are only meaningful when do_sample=True;
                # passing None lets generate() fall back to its defaults.
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                top_k=top_k if do_sample and top_k > 0 else None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            # Return only the newly generated tokens, not the echoed prompt.
            new_tokens = gen_ids[0, input_len:]
            text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            return {"generated_text": text}

        # Batch support: a list that is NOT itself a single messages list is
        # treated as a batch of independent inputs.
        if isinstance(inputs, list) and not _is_messages(inputs):
            return [_one(x) for x in inputs]
        else:
            return _one(inputs)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers==4.57.6
2
+ torch
3
+ accelerate