gary-boon Claude committed on
Commit
5aed1a9
·
1 Parent(s): 992dc8c

Add layer_stride parameter for PromptDiff optimization

Browse files

- Add layer_stride parameter to control which layers are captured
- Default to 1 (all layers) for AttentionExplorer
- PromptDiff can use layer_stride=2 for every other layer
- Reduces matrix count from 20 to 10 for better visualization fit

🤖 Generated with Claude Code (https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

backend/__pycache__/model_service.cpython-310.pyc CHANGED
Binary files a/backend/__pycache__/model_service.cpython-310.pyc and b/backend/__pycache__/model_service.cpython-310.pyc differ
 
backend/model_service.py CHANGED
@@ -47,6 +47,7 @@ class GenerationRequest(BaseModel):
47
  top_p: Optional[float] = None
48
  extract_traces: bool = True
49
  sampling_rate: float = 0.005
 
50
 
51
  class AblatedGenerationRequest(BaseModel):
52
  prompt: str
@@ -500,7 +501,8 @@ class ModelManager:
500
  temperature: float = 0.7,
501
  top_k: Optional[int] = None,
502
  top_p: Optional[float] = None,
503
- sampling_rate: float = 0.005
 
504
  ) -> Dict[str, Any]:
505
  """Generate text with trace extraction"""
506
  if not self.model or not self.tokenizer:
@@ -643,8 +645,8 @@ class ModelManager:
643
  # Clear previous partial traces and add complete ones
644
  traces = [] # Reset traces to only include complete attention patterns
645
 
646
- # Capture ALL layers for complete visualization
647
- for layer_idx in range(num_layers):
648
  try:
649
  # Get all token IDs (prompt + generated)
650
  all_token_ids = inputs["input_ids"][0].tolist()
@@ -892,7 +894,8 @@ async def generate(request: GenerationRequest, authenticated: bool = Depends(ver
892
  temperature=request.temperature,
893
  top_k=request.top_k,
894
  top_p=request.top_p,
895
- sampling_rate=request.sampling_rate if request.extract_traces else 0
 
896
  )
897
  return result
898
 
 
47
  top_p: Optional[float] = None
48
  extract_traces: bool = True
49
  sampling_rate: float = 0.005
50
+ layer_stride: int = 1 # 1 = all layers, 2 = every other layer, etc.
51
 
52
  class AblatedGenerationRequest(BaseModel):
53
  prompt: str
 
501
  temperature: float = 0.7,
502
  top_k: Optional[int] = None,
503
  top_p: Optional[float] = None,
504
+ sampling_rate: float = 0.005,
505
+ layer_stride: int = 1 # 1 = all layers, 2 = every other layer, etc.
506
  ) -> Dict[str, Any]:
507
  """Generate text with trace extraction"""
508
  if not self.model or not self.tokenizer:
 
645
  # Clear previous partial traces and add complete ones
646
  traces = [] # Reset traces to only include complete attention patterns
647
 
648
+ # Capture layers based on stride (1 = all, 2 = every other, etc.)
649
+ for layer_idx in range(0, num_layers, layer_stride):
650
  try:
651
  # Get all token IDs (prompt + generated)
652
  all_token_ids = inputs["input_ids"][0].tolist()
 
894
  temperature=request.temperature,
895
  top_k=request.top_k,
896
  top_p=request.top_p,
897
+ sampling_rate=request.sampling_rate if request.extract_traces else 0,
898
+ layer_stride=request.layer_stride
899
  )
900
  return result
901
 
components/ui/tooltip.tsx ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ "use client"
2
+
3
+ import * as React from "react"
4
+ import * as TooltipPrimitive from "@radix-ui/react-tooltip"
5
+
6
+ import { cn } from "@/lib/utils"
7
+
8
+ const TooltipProvider = TooltipPrimitive.Provider
9
+
10
+ const Tooltip = TooltipPrimitive.Root
11
+
12
+ const TooltipTrigger = TooltipPrimitive.Trigger
13
+
14
+ const TooltipContent = React.forwardRef<
15
+ React.ElementRef<typeof TooltipPrimitive.Content>,
16
+ React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>
17
+ >(({ className, sideOffset = 4, ...props }, ref) => (
18
+ <TooltipPrimitive.Content
19
+ ref={ref}
20
+ sideOffset={sideOffset}
21
+ className={cn(
22
+ "z-50 overflow-hidden rounded-md bg-gray-900 px-3 py-1.5 text-sm text-gray-100 shadow-md animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 border border-gray-800",
23
+ className
24
+ )}
25
+ {...props}
26
+ />
27
+ ))
28
+ TooltipContent.displayName = TooltipPrimitive.Content.displayName
29
+
30
+ export { Tooltip, TooltipTrigger, TooltipContent, TooltipProvider }