/* Attn API and Types */
import * as d3 from "d3";
import URLHandler from "../utils/URLHandler";
import {cleanSpecials} from "../utils/Util";
import {AnalyzeResponse, AnalyzeResult, TokenWithOffset} from "./generatedSchemas";

/** `TokenWithOffset` plus an optional frontend-only `bpe_merged` flag. */
export type FrontendToken = TokenWithOffset & { bpe_merged?: boolean };

/** Backend `AnalyzeResult` enriched with fields computed or injected on the frontend. */
export interface FrontendAnalyzeResult extends AnalyzeResult {
    bpe_strings: FrontendToken[];
    originalTokens: FrontendToken[];
    mergedTokens: FrontendToken[];
    originalToMergedMap: number[];
    /** Original input text, injected by the frontend (taken from `request.text`). */
    originalText: string;
}

/** @deprecated Use {@link FrontendAnalyzeResult} instead. */
export type AnalyzedText = FrontendAnalyzeResult;

/**
 * Alias used for stored demos (data after saving); `AnalyzeResponse` is used for
 * live API analysis (data before saving).
 */
export type AnalysisData = AnalyzeResponse;

export type { AnalyzeResponse, TokenWithOffset };

/** Progress callback for streaming analysis: (step, totalSteps, stage, percentage?). */
export type ProgressCallback = (step: number, totalSteps: number, stage: string, percentage?: number) => void;

/** Content type sent with every JSON request body. */
const JSON_CONTENT_TYPE = "application/json; charset=UTF-8";

/** Request body accepted by `/api/analyze`. */
interface AnalyzePayload {
    model: string;
    text: string;
    bitmask?: number[];
    stream?: true;
}

/**
 * Messages this client understands on the `/api/analyze` SSE stream.
 * Any other `type` is ignored.
 */
type SSEMessage =
    | { type: "progress"; step: number; total_steps: number; stage: string; percentage?: number }
    | { type: "result"; data: unknown }
    | { type: "error"; message?: string };

/** Shape of the backend's unified error envelope. */
interface FailureResponse {
    success: false;
    message?: string;
}

/** True when `value` is a non-null object whose `success` field is exactly `false`. */
function isFailureResponse(value: unknown): value is FailureResponse {
    return typeof value === "object" && value !== null && (value as { success?: unknown }).success === false;
}

/**
 * Extracts the payload of an SSE `data:` line, or returns null for any other line
 * (comments, `event:`/`id:` fields, blank separators).
 * Per the SSE format, a single space after the colon is optional and stripped, and a
 * trailing CR from CRLF line endings is removed.
 */
function extractSSEData(line: string): string | null {
    const clean = line.endsWith("\r") ? line.slice(0, -1) : line;
    if (!clean.startsWith("data:")) {
        return null;
    }
    const value = clean.slice("data:".length);
    return value.startsWith(" ") ? value.slice(1) : value;
}

/**
 * Client for the backend HTTP API: demo file/folder management, text analysis
 * (plain JSON or SSE streaming with progress), URL text extraction and model management.
 */
export class TextAnalysisAPI {
    private adminToken: string | null = null;
    private readonly baseURL: string;

    /**
     * @param baseURL Server base URL; when null/undefined, `URLHandler.basicURL()` is used.
     */
    constructor(baseURL: string | null = null) {
        this.baseURL = baseURL ?? URLHandler.basicURL();
    }

    /**
     * Sets (or clears, with null) the admin token sent as `X-Admin-Token` on
     * requests built with `getHeaders`.
     */
    public setAdminToken(token: string | null): void {
        this.adminToken = token;
    }

    /**
     * Builds JSON request headers, merging `additionalHeaders` and adding
     * `X-Admin-Token` when an admin token is set.
     */
    private getHeaders(additionalHeaders?: Record<string, string>): Record<string, string> {
        const headers: Record<string, string> = {
            "Content-type": JSON_CONTENT_TYPE,
            ...additionalHeaders
        };
        if (this.adminToken) {
            headers["X-Admin-Token"] = this.adminToken;
        }
        return headers;
    }

    /**
     * POSTs `body` as JSON to `endpoint` (relative to `baseURL`) and resolves with the parsed reply.
     * Defaults to the authenticated headers from `getHeaders()`.
     */
    private postJSON<T>(
        endpoint: string,
        body: unknown,
        headers: Record<string, string> = this.getHeaders()
    ): Promise<T> {
        return d3.json(this.baseURL + endpoint, {
            method: "POST",
            body: JSON.stringify(body),
            headers
        });
    }

    /** Lists demos and folders under `path` (server root when omitted or empty). */
    public list_demos(path?: string): Promise<{ path: string, items: Array<{type: 'folder'|'file', name: string, path: string}> }> {
        const query = path ? `?path=${encodeURIComponent(path)}` : "";
        return d3.json(this.baseURL + "/api/list_demos" + query);
    }

    /** Saves analysis `data` as demo `name` under `path`; `overwrite` replaces an existing file. */
    public save_demo(name: string, data: AnalyzeResponse, path: string = '/', overwrite: boolean = false): Promise<{ success: boolean, exists?: boolean, message?: string, file?: string }> {
        return this.postJSON("/api/save_demo", { name, data, path, overwrite });
    }

    /** Deletes the demo file `file`. */
    public delete_demo(file: string): Promise<{ success: boolean, message?: string }> {
        return this.postJSON("/api/delete_demo", { file });
    }

    /** Moves the demo file `file` into folder `targetPath`. */
    public move_demo(file: string, targetPath: string): Promise<{ success: boolean, message?: string }> {
        return this.postJSON("/api/move_demo", { file, target_path: targetPath });
    }

    /** Moves the folder `path` into folder `targetPath`. */
    public move_folder(path: string, targetPath: string): Promise<{ success: boolean, message?: string }> {
        // NOTE(review): this posts to /api/move_demo (not /api/move_folder), sending `path`
        // instead of `file`; presumably the backend move endpoint handles both — confirm.
        return this.postJSON("/api/move_demo", { path, target_path: targetPath });
    }

    /** Renames the demo file `file` to `newName`. */
    public rename_demo(file: string, newName: string): Promise<{ success: boolean, message?: string }> {
        return this.postJSON("/api/rename_demo", { file, new_name: newName });
    }

    /** Renames the folder `path` to `newName`. */
    public rename_folder(path: string, newName: string): Promise<{ success: boolean, message?: string }> {
        return this.postJSON("/api/rename_folder", { path, new_name: newName });
    }

    /** Deletes the folder `path`. */
    public delete_folder(path: string): Promise<{ success: boolean, message?: string }> {
        return this.postJSON("/api/delete_folder", { path });
    }

    /** Lists all folder paths known to the server. */
    public list_all_folders(): Promise<{ folders: string[] }> {
        return d3.json(this.baseURL + "/api/list_all_folders");
    }

    /** Creates folder `folderName` inside `parentPath`. */
    public create_folder(parentPath: string, folderName: string): Promise<{ success: boolean, message?: string }> {
        return this.postJSON("/api/create_folder", { parent_path: parentPath, folder_name: folderName });
    }

    /**
     * Builds the `/api/analyze` request body. `text` is passed through `cleanSpecials`;
     * `bitmask` and `stream` are only included when set.
     */
    private buildAnalyzePayload(
        model: string,
        text: string,
        bitmask: number[] | null = null,
        stream: boolean = false
    ): AnalyzePayload {
        const payload: AnalyzePayload = { model, text: cleanSpecials(text) };
        if (bitmask) {
            payload.bitmask = bitmask;
        }
        if (stream) {
            payload.stream = true;
        }
        return payload;
    }

    /**
     * Analyzes `text` with `model`.
     *
     * @param bitmask Optional bitmask forwarded to the backend (in both modes).
     * @param stream When true, uses the SSE endpoint and reports progress via `onProgress`;
     *               otherwise performs a single JSON request.
     * @param onProgress Progress callback; only invoked in streaming mode.
     * @returns The analysis response. Rejects with an `Error` when the backend returns
     *          `success: false`, sends an `error` event, or the request fails.
     */
    public analyze(
        model: string,
        text: string,
        bitmask: number[] | null = null,
        stream: boolean = false,
        onProgress?: ProgressCallback
    ): Promise<AnalyzeResponse> {
        if (stream) {
            return this.analyzeWithProgress(model, text, onProgress, bitmask);
        }
        const payload = this.buildAnalyzePayload(model, text, bitmask, false);
        // Analysis is sent without the admin token, matching the original behavior.
        return this.postJSON<unknown>("/api/analyze", payload, { "Content-type": JSON_CONTENT_TYPE })
            .then((response) => {
                if (isFailureResponse(response)) {
                    throw new Error(response.message || '分析失败');
                }
                return response as AnalyzeResponse;
            });
    }

    /**
     * Extracts text content from `url` via the backend.
     *
     * @param url URL to extract text from.
     * @returns The extraction result; rejects with an `Error` when the backend returns `success: false`.
     */
    public fetchUrlText(url: string): Promise<{success: boolean, text?: string, url?: string, char_count?: number, message?: string}> {
        return this.postJSON<{success: boolean, text?: string, url?: string, char_count?: number, message?: string}>(
            "/api/fetch_url",
            { url },
            { "Content-type": JSON_CONTENT_TYPE }
        ).then((response) => {
            if (isFailureResponse(response)) {
                throw new Error(response.message || 'URL 文本提取失败');
            }
            return response;
        });
    }

    /** Lists the models available on the server. */
    public getAvailableModels(): Promise<{ success: boolean, models: string[] }> {
        return d3.json(this.baseURL + "/api/available_models");
    }

    /** Returns the currently loaded model and its runtime configuration. */
    public getCurrentModel(): Promise<{ success: boolean, model: string, loading: boolean, device_type: 'cpu' | 'cuda' | 'mps', use_int8: boolean, use_bfloat16: boolean }> {
        return d3.json(this.baseURL + "/api/current_model");
    }

    /**
     * Switches the server's active model (requires admin privileges).
     * Omitted quantization flags are sent as `false`.
     */
    public switchModel(
        model: string,
        use_int8?: boolean,
        use_bfloat16?: boolean
    ): Promise<{ success: boolean, message?: string, model?: string }> {
        return this.postJSON("/api/switch_model", {
            model,
            use_int8: use_int8 ?? false,
            use_bfloat16: use_bfloat16 ?? false
        });
    }

    /**
     * Streams an analysis over SSE, reporting progress through `onProgress`.
     *
     * The promise settles on the first `result` or `error` event.
     * It rejects if the HTTP request fails or the response has no body.
     * It also rejects if the stream closes without delivering a result.
     */
    private analyzeWithProgress(
        model: string,
        text: string,
        onProgress?: ProgressCallback,
        bitmask: number[] | null = null
    ): Promise<AnalyzeResponse> {
        return new Promise<AnalyzeResponse>((resolve, reject) => {
            const payload = this.buildAnalyzePayload(model, text, bitmask, true);

            const dispatchLine = (line: string): void => {
                const data = extractSSEData(line);
                if (data !== null) {
                    this.processSSEMessage(data, onProgress, resolve, reject);
                }
            };

            const consume = async (): Promise<void> => {
                // fetch (not EventSource) because the request must be a POST with a JSON body.
                const response = await fetch(this.baseURL + "/api/analyze", {
                    method: "POST",
                    headers: { "Content-Type": JSON_CONTENT_TYPE },
                    body: JSON.stringify(payload)
                });
                if (!response.ok) {
                    throw new Error(`HTTP error! status: ${response.status}`);
                }
                if (!response.body) {
                    throw new Error("Analysis response has no readable body");
                }

                const reader = response.body.getReader();
                const decoder = new TextDecoder();
                let buffer = "";
                for (;;) {
                    const { done, value } = await reader.read();
                    if (done) {
                        break;
                    }
                    buffer += decoder.decode(value, { stream: true });
                    const lines = buffer.split("\n");
                    // The last element may be an incomplete line; keep it for the next chunk.
                    buffer = lines.pop() ?? "";
                    lines.forEach(dispatchLine);
                }

                // Flush bytes held by the streaming decoder and handle a final line
                // that was not newline-terminated.
                buffer += decoder.decode();
                buffer.split("\n").forEach(dispatchLine);

                // No-op if a result/error event already settled the promise.
                reject(new Error("Analysis stream ended without a result"));
            };

            consume().catch(reject);
        });
    }

    /**
     * Handles one SSE `data` payload.
     * - `progress` events are forwarded to `onProgress`.
     * - `result` events resolve, or reject when the data has `success: false`.
     * - `error` events reject.
     * Unparseable payloads are logged and ignored.
     */
    private processSSEMessage(
        data: string,
        onProgress: ProgressCallback | undefined,
        resolve: (value: AnalyzeResponse) => void,
        reject: (reason?: unknown) => void
    ): void {
        try {
            // The message shape is trusted from the backend; unknown types fall through.
            const message = JSON.parse(data) as SSEMessage;
            switch (message.type) {
                case "progress":
                    if (onProgress) {
                        onProgress(message.step, message.total_steps, message.stage, message.percentage);
                    }
                    break;
                case "result":
                    if (isFailureResponse(message.data)) {
                        reject(new Error(message.data.message || '分析失败'));
                    } else {
                        resolve(message.data as AnalyzeResponse);
                    }
                    break;
                case "error":
                    reject(new Error(message.message || '分析失败'));
                    break;
                default:
                    // Unknown event types are ignored.
                    break;
            }
        } catch (e) {
            // Best-effort: a malformed payload should not abort the whole stream.
            console.warn('Failed to parse SSE message:', e, data);
        }
    }
}