File size: 6,602 Bytes
0ce9643
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import type { ModelInfo } from '../types/model';
import type { ActivationData, AnomalyData, CircuitData, SAEData, SAEInfoResponse, StructuralData, WeightData } from '../types/scan';
import type { PerturbResult, PatchResult } from '../types/perturb';
import type { CausalTraceResult } from '../types/causalTrace';
import type { DiagnosticReport, ReportRequest } from '../types/report';
import type { BatteryResult, TestCase } from '../types/battery';
import type { TokenStatus, CacheStatus, HubSearchResult } from '../types/settings';

/**
 * Prefix prepended to every request path.
 *
 * Read from the build-time variable `VITE_API_BASE_URL` (Vite-style `import.meta.env`);
 * an unset or empty value falls back to the same-origin `/api`. Trailing slashes are
 * stripped because every path handed to `request` already starts with `/`: without
 * this, `https://host/api/` would yield `https://host/api//model/list`, and a value of
 * `/` would yield the protocol-relative URL `//model/list` (i.e. host "model").
 */
const BASE_URL: string = (import.meta.env.VITE_API_BASE_URL || '/api').replace(/\/+$/, '');

/**
 * Error thrown by `request` when the backend answers with a non-2xx status.
 *
 * Exported so callers can narrow a caught value with `err instanceof ApiError`
 * and branch on `status` (e.g. to tell a 4xx from a 5xx).
 */
export class ApiError extends Error {
  /** HTTP status code of the failed response. */
  readonly status: number;

  /**
   * @param status - HTTP status code of the failed response.
   * @param message - Human-readable description of the failure.
   */
  constructor(status: number, message: string) {
    super(message);
    this.name = 'ApiError';
    this.status = status;
  }
}

/**
 * Sends a request to the backend and parses the JSON response.
 *
 * `Content-Type: application/json` is sent by default. Headers supplied in
 * `options.headers` are merged on top of it (same-name caller headers win) rather
 * than replacing the whole header set.
 *
 * @param path - Path appended to `BASE_URL`, including any query string.
 * @param options - Extra `fetch` options (method, body, signal, headers, ...).
 * @returns The parsed JSON body. It is not validated at runtime; `T` is trusted.
 * @throws ApiError if the response status is not 2xx. The message comes from the
 *   body's `detail` field when present, otherwise from the HTTP status.
 */
async function request<T>(path: string, options?: RequestInit): Promise<T> {
  const headers = new Headers(options?.headers);
  if (!headers.has('Content-Type')) {
    headers.set('Content-Type', 'application/json');
  }
  const res = await fetch(`${BASE_URL}${path}`, { ...options, headers });
  if (!res.ok) {
    // A non-JSON error body (e.g. an HTML proxy error page) must not mask the HTTP error.
    const body: unknown = await res.json().catch(() => null);
    // statusText is empty over HTTP/2, so fall back to the numeric status.
    const fallback = res.statusText || `HTTP ${res.status}`;
    throw new ApiError(res.status, describeErrorBody(body, fallback));
  }
  return res.json() as Promise<T>;
}

/** Turns a parsed error response body into a readable message, or returns `fallback`. */
function describeErrorBody(body: unknown, fallback: string): string {
  if (typeof body !== 'object' || body === null) return fallback;
  const detail = (body as { detail?: unknown }).detail;
  if (!detail) return fallback;
  if (typeof detail === 'string') return detail;
  if (Array.isArray(detail)) {
    // Validation failures may report `detail` as a list of `{ loc, msg }` entries
    // (FastAPI-style); passing the array straight to Error would print "[object Object]".
    return detail.map(formatDetailItem).join('; ') || fallback;
  }
  return JSON.stringify(detail);
}

/** Formats one entry of a list-valued `detail` as `"loc.path: msg"` when it has that shape. */
function formatDetailItem(item: unknown): string {
  if (typeof item !== 'object' || item === null) return String(item);
  const { loc, msg } = item as { loc?: unknown; msg?: unknown };
  if (typeof msg !== 'string') return JSON.stringify(item);
  return Array.isArray(loc) && loc.length > 0 ? `${loc.join('.')}: ${msg}` : msg;
}

/**
 * One element of the array returned by `GET /model/list`.
 *
 * Field meanings below marked "presumably" are inferred from names only and
 * should be confirmed against the backend schema.
 */
export interface ModelListEntry {
  // --- identity -----------------------------------------------------------
  /** Model identifier; the same kind of string `api.model.load` takes as `model_id`. */
  model_id: string;
  /** Name intended for display in the UI. */
  display_name: string;
  /** Family/group label as reported by the backend. */
  family: string;
  /** Parameter count as a preformatted string (presumably e.g. "124M" — confirm). */
  params: string;

  // --- status flags -------------------------------------------------------
  /** Presumably true for the model currently loaded on the backend — confirm. */
  is_loaded: boolean;
  /** Presumably marks models that require an access token to download — confirm. */
  gated: boolean;
  /** Presumably TransformerLens compatibility, mirroring the `tl_only` search filter — confirm. */
  tl_compat: boolean;

  // --- provenance ---------------------------------------------------------
  /** Optional origin of the entry; exact semantics of each value to be confirmed. */
  source?: 'registry' | 'dynamic';
}

/**
 * Typed client for the backend HTTP API, grouped by resource.
 *
 * Every method delegates to `request`, so it resolves with the parsed JSON body
 * and rejects with `ApiError` on a non-2xx response. Methods that take an optional
 * trailing `signal` forward it to `fetch`, letting callers abort in-flight requests
 * (e.g. when the user navigates away or starts a newer run).
 */
export const api = {
  model: {
    list: () => request<ModelListEntry[]>('/model/list'),
    load: (model_id: string, device = 'auto', signal?: AbortSignal) =>
      request<ModelInfo>('/model/load', {
        method: 'POST',
        body: JSON.stringify({ model_id, device }),
        signal,
      }),
    info: () => request<ModelInfo>('/model/info'),
    unload: () => request<{ status: string }>('/model/unload', { method: 'DELETE' }),
    // `signal` lets search-as-you-type callers cancel stale queries.
    search: (q: string, limit = 20, tlOnly = false, signal?: AbortSignal) =>
      request<HubSearchResult[]>(
        `/model/search?q=${encodeURIComponent(q)}&limit=${limit}&tl_only=${tlOnly}`,
        { signal },
      ),
  },
  scan: {
    structural: (signal?: AbortSignal) =>
      request<StructuralData>('/scan/structural', { method: 'POST', signal }),
    weights: (layers?: string[], signal?: AbortSignal) =>
      request<WeightData>('/scan/weights', {
        method: 'POST',
        // `null` (not an omitted key) is sent when no layer filter is given.
        body: JSON.stringify({ layers: layers ?? null }),
        signal,
      }),
    activation: (prompt: string, signal?: AbortSignal) =>
      request<ActivationData>('/scan/activation', {
        method: 'POST',
        body: JSON.stringify({ prompt }),
        signal,
      }),
    circuits: (prompt: string, targetTokenIdx = -1, signal?: AbortSignal) =>
      request<CircuitData>('/scan/circuits', {
        method: 'POST',
        body: JSON.stringify({ prompt, target_token_idx: targetTokenIdx }),
        signal,
      }),
    anomaly: (prompt: string, signal?: AbortSignal) =>
      request<AnomalyData>('/scan/anomaly', {
        method: 'POST',
        body: JSON.stringify({ prompt }),
        signal,
      }),
  },
  perturb: {
    zero: (component: string, prompt: string, signal?: AbortSignal) =>
      request<PerturbResult>('/perturb/zero', {
        method: 'POST',
        body: JSON.stringify({ component, prompt }),
        signal,
      }),
    amplify: (component: string, prompt: string, factor = 2.0, signal?: AbortSignal) =>
      request<PerturbResult>('/perturb/amplify', {
        method: 'POST',
        body: JSON.stringify({ component, factor, prompt }),
        signal,
      }),
    ablate: (component: string, prompt: string, signal?: AbortSignal) =>
      request<PerturbResult>('/perturb/ablate', {
        method: 'POST',
        body: JSON.stringify({ component, prompt }),
        signal,
      }),
    patch: (
      cleanPrompt: string,
      corruptPrompt: string,
      component: string,
      targetIdx = -1,
      signal?: AbortSignal,
    ) =>
      request<PatchResult>('/perturb/patch', {
        method: 'POST',
        body: JSON.stringify({
          clean_prompt: cleanPrompt,
          corrupt_prompt: corruptPrompt,
          component,
          target_token_idx: targetIdx,
        }),
        signal,
      }),
    causalTrace: (cleanPrompt: string, corruptPrompt: string, targetIdx = -1, signal?: AbortSignal) =>
      request<CausalTraceResult>('/perturb/causal-trace', {
        method: 'POST',
        body: JSON.stringify({
          clean_prompt: cleanPrompt,
          corrupt_prompt: corruptPrompt,
          target_token_idx: targetIdx,
        }),
        signal,
      }),
    reset: () =>
      request<{ status: string }>('/perturb/reset', { method: 'POST' }),
  },
  report: {
    generate: (req: ReportRequest = {}, signal?: AbortSignal) =>
      request<DiagnosticReport>('/report/generate', {
        method: 'POST',
        body: JSON.stringify(req),
        signal,
      }),
  },
  battery: {
    run: (
      categories?: string[],
      locale?: string,
      includeSae?: boolean,
      saeLayer?: number | null,
      signal?: AbortSignal,
    ) =>
      request<BatteryResult>('/battery/run', {
        method: 'POST',
        // Omitted arguments are sent as explicit defaults so the body shape is stable.
        body: JSON.stringify({
          categories: categories ?? null,
          locale: locale ?? 'en',
          include_sae: includeSae ?? false,
          sae_layer: saeLayer ?? null,
        }),
        signal,
      }),
    tests: () => request<TestCase[]>('/battery/tests'),
  },
  sae: {
    info: () => request<SAEInfoResponse>('/sae/info'),
    support: () => request<Record<string, boolean>>('/sae/support'),
    scan: (prompt: string, layerIdx: number, topK = 20, signal?: AbortSignal) =>
      request<SAEData>('/sae/scan', {
        method: 'POST',
        body: JSON.stringify({ prompt, layer_idx: layerIdx, top_k: topK }),
        signal,
      }),
  },
  collab: {
    create: (displayName = 'Host') =>
      request<{ session_id: string; host_id: string; join_url: string }>(
        `/collab/create?display_name=${encodeURIComponent(displayName)}`,
        { method: 'POST' },
      ),
    // Session IDs are encoded so a value containing `/`, `?` or `#` cannot
    // rewrite the request path or smuggle in a query string.
    get: (sessionId: string) =>
      request<{
        session_id: string;
        host_id: string;
        host_name: string;
        participant_count: number;
      }>(`/collab/${encodeURIComponent(sessionId)}`),
    delete: (sessionId: string) =>
      request<{ status: string }>(`/collab/${encodeURIComponent(sessionId)}`, { method: 'DELETE' }),
    list: () =>
      request<
        Array<{
          session_id: string;
          host_name: string;
          participant_count: number;
          created_at: string;
        }>
      >('/collab/list'),
  },
  settings: {
    updateToken: (token: string) =>
      request<TokenStatus>('/settings/token', {
        method: 'POST',
        body: JSON.stringify({ token }),
      }),
    clearToken: () =>
      request<TokenStatus>('/settings/token', { method: 'DELETE' }),
    tokenStatus: () => request<TokenStatus>('/settings/token/status'),
    cacheStatus: () => request<CacheStatus>('/settings/cache'),
    clearCache: () =>
      request<{ status: string }>('/settings/cache', { method: 'DELETE' }),
  },
};