|
|
|
|
|
/**
 * Minimal client for the Hugging Face Inference API with support for
 * streamed (server-sent-event style) text-generation responses.
 */
export class HuggingFaceService {
  /**
   * @param {string} token - Hugging Face API token, sent as a Bearer token.
   */
  constructor(token) {
    this.token = token;
  }

  /**
   * Streams a completion for a full multi-turn conversation prompt.
   *
   * @param {Array<{role: string, content: string}>} messages - Conversation history.
   * @param {{endpoint: string}} modelConfig - Model id (e.g. "org/model") appended to the API base URL.
   * @param {(text: string) => void} onChunk - Called with each newly generated text fragment.
   * @param {() => void} onComplete - Called once the stream has been fully consumed.
   * @param {(message: string) => void} onError - Called with the error message on any failure.
   */
  async streamChatCompletion(messages, modelConfig, onChunk, onComplete, onError) {
    console.log('Starting chat completion with model:', modelConfig.endpoint);
    await this.streamInference(
      this.formatMessagesForInference(messages),
      modelConfig,
      onChunk,
      onComplete,
      onError
    );
  }

  /**
   * Streams a completion using only the last user message as the prompt.
   * Same callback contract as {@link streamChatCompletion}.
   *
   * @param {Array<{role: string, content: string}>} messages
   * @param {{endpoint: string}} modelConfig
   * @param {(text: string) => void} onChunk
   * @param {() => void} onComplete
   * @param {(message: string) => void} onError
   */
  async streamChatCompletionAlt(messages, modelConfig, onChunk, onComplete, onError) {
    console.log('Using chat completion format with model:', modelConfig.endpoint);
    await this.streamInference(
      this.formatChatPrompt(messages),
      modelConfig,
      onChunk,
      onComplete,
      onError
    );
  }

  /**
   * Shared implementation: POSTs `inputs` to the model endpoint and incrementally
   * parses the streamed response, emitting only NEW text via `onChunk`.
   *
   * Fixes over the previous duplicated implementations:
   * - The HTTP error body (`errorText`) is included in the thrown Error, not
   *   the generic `statusText`.
   * - `generated_text` payloads (which carry the full text so far) are deduped
   *   with a prefix strip instead of `String.replace`, which would remove the
   *   first occurrence anywhere in the string.
   * - Any trailing line left in the buffer when the stream ends is flushed and
   *   processed instead of being silently dropped.
   *
   * @param {string} inputs - Prompt text to send.
   * @param {{endpoint: string}} modelConfig
   * @param {(text: string) => void} onChunk
   * @param {() => void} onComplete
   * @param {(message: string) => void} onError
   */
  async streamInference(inputs, modelConfig, onChunk, onComplete, onError) {
    try {
      // Model ids contain '/' (e.g. "org/model"), which must remain a path
      // separator, so the endpoint is appended verbatim.
      const response = await fetch(
        `https://api-inference.huggingface.co/models/${modelConfig.endpoint}`,
        {
          method: 'POST',
          headers: {
            'Authorization': `Bearer ${this.token}`,
            'Content-Type': 'application/json',
          },
          body: JSON.stringify({
            inputs,
            parameters: {
              max_new_tokens: 1024,
              temperature: 0.7,
              top_p: 0.9,
              do_sample: true,
              return_full_text: false,
            },
            options: {
              wait_for_model: true,
              use_cache: false,
            },
            stream: true,
          }),
        }
      );

      if (!response.ok) {
        const errorText = await response.text();
        console.error('API Error:', response.status, errorText);
        // Surface the response body — statusText is often empty or generic.
        throw new Error(`API error: ${response.status} - ${errorText}`);
      }

      const reader = response.body.getReader();
      const decoder = new TextDecoder();
      let buffer = '';
      // Full text emitted so far; used to strip the already-seen prefix from
      // payloads that carry the complete generation on every event.
      let accumulatedText = '';

      // Parses one SSE line and emits any newly generated text.
      const processLine = (line) => {
        if (line.trim() === '') return;
        if (!line.startsWith('data: ') || line === 'data: [DONE]') return;

        const jsonData = line.slice(6);
        if (!jsonData.trim()) return;

        let data;
        try {
          data = JSON.parse(jsonData);
        } catch (e) {
          console.log('Skipping invalid JSON line:', line);
          return;
        }

        let newText = '';
        if (data.token?.text) {
          // Token events are already incremental.
          newText = data.token.text;
        } else {
          // `generated_text` events carry the whole text so far; strip the
          // known prefix (never String.replace — it matches anywhere).
          const full = data.generated_text ?? data[0]?.generated_text ?? '';
          newText = full.startsWith(accumulatedText)
            ? full.slice(accumulatedText.length)
            : full;
        }

        if (newText) {
          accumulatedText += newText;
          onChunk(newText);
        }
      };

      while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        // Keep the last (possibly partial) line for the next read.
        buffer = lines.pop() ?? '';
        for (const line of lines) {
          processLine(line);
        }
      }

      // Flush any bytes held by the decoder and process a trailing line that
      // arrived without a final newline.
      buffer += decoder.decode();
      if (buffer.trim() !== '') {
        processLine(buffer);
      }

      onComplete();
    } catch (error) {
      console.error('Stream error:', error);
      onError(error.message);
    }
  }

  /**
   * Formats a conversation as a plain "Human:/Assistant:" transcript prompt.
   * A single message is passed through verbatim.
   *
   * @param {Array<{role: string, content: string}>} messages
   * @returns {string} Prompt text ending with "Assistant: " when multi-turn.
   */
  formatMessagesForInference(messages) {
    if (messages.length === 0) return '';

    if (messages.length === 1) {
      return messages[0].content;
    }

    let conversation = '';
    for (const msg of messages) {
      const role = msg.role === 'user' ? 'Human' : 'Assistant';
      conversation += `${role}: ${msg.content}\n`;
    }
    conversation += 'Assistant: ';

    return conversation;
  }

  /**
   * Returns only the content of the most recent message (single-turn prompt).
   *
   * @param {Array<{role: string, content: string}>} messages
   * @returns {string} Last message content, or '' for an empty history.
   */
  formatChatPrompt(messages) {
    if (messages.length === 0) return '';

    const lastMessage = messages[messages.length - 1];
    return lastMessage.content;
  }
}