kaiiddo committed (verified)
Commit 60866c1 · 1 Parent(s): 7940110

Update src/services/huggingfaceService.js

Files changed (1):
  1. src/services/huggingfaceService.js +98 -20
src/services/huggingfaceService.js CHANGED
@@ -6,8 +6,11 @@ export class HuggingFaceService {
 
   async streamChatCompletion(messages, modelConfig, onChunk, onComplete, onError) {
     try {
+      console.log('Starting chat completion with model:', modelConfig.endpoint);
+
+      // Use the chat completions endpoint which is more reliable
       const response = await fetch(
-        `https://api-inference.huggingface.co/models/${modelConfig.endpoint}`,
+        'https://api-inference.huggingface.co/models/' + modelConfig.endpoint,
         {
           method: 'POST',
           headers: {
@@ -15,20 +18,27 @@ export class HuggingFaceService {
             'Content-Type': 'application/json',
           },
           body: JSON.stringify({
-            inputs: messages[messages.length - 1].content,
+            inputs: this.formatMessagesForInference(messages),
             parameters: {
               max_new_tokens: 1024,
               temperature: 0.7,
               top_p: 0.9,
+              do_sample: true,
               return_full_text: false
             },
+            options: {
+              wait_for_model: true,
+              use_cache: false
+            },
             stream: true
           })
         }
       );
 
       if (!response.ok) {
-        throw new Error(`API error: ${response.status}`);
+        const errorText = await response.text();
+        console.error('API Error:', response.status, errorText);
+        throw new Error(`API error: ${response.status} - ${response.statusText}`);
       }
 
       const reader = response.body.getReader();
@@ -45,14 +55,24 @@ export class HuggingFaceService {
         buffer = lines.pop() || '';
 
         for (const line of lines) {
+          if (line.trim() === '') continue;
+
           if (line.startsWith('data: ') && line !== 'data: [DONE]') {
             try {
-              const data = JSON.parse(line.slice(6));
-              if (data.token && data.token.text) {
-                onChunk(data.token.text);
+              const jsonData = line.slice(6);
+              if (jsonData.trim()) {
+                const data = JSON.parse(jsonData);
+                // Handle different response formats
+                if (data.token && data.token.text) {
+                  onChunk(data.token.text);
+                } else if (data.generated_text) {
+                  onChunk(data.generated_text);
+                } else if (data[0] && data[0].generated_text) {
+                  onChunk(data[0].generated_text);
+                }
               }
             } catch (e) {
-              // Skip invalid JSON
+              console.log('Skipping invalid JSON line:', line);
             }
           }
         }
@@ -60,6 +80,7 @@ export class HuggingFaceService {
 
       onComplete();
     } catch (error) {
+      console.error('Stream error:', error);
       onError(error.message);
     }
   }
@@ -67,9 +88,10 @@ export class HuggingFaceService {
   // Alternative method using chat completion format
   async streamChatCompletionAlt(messages, modelConfig, onChunk, onComplete, onError) {
     try {
-      // Using the provider-based endpoint
+      console.log('Using chat completion format with model:', modelConfig.endpoint);
+
       const response = await fetch(
-        'https://api-inference.huggingface.co/chat/completions',
+        'https://api-inference.huggingface.co/models/' + modelConfig.endpoint,
         {
           method: 'POST',
           headers: {
@@ -77,23 +99,33 @@ export class HuggingFaceService {
             'Content-Type': 'application/json',
           },
           body: JSON.stringify({
-            model: modelConfig.endpoint,
-            messages: messages,
-            stream: true,
-            max_tokens: 1024,
-            temperature: 0.7
+            inputs: this.formatChatPrompt(messages),
+            parameters: {
+              max_new_tokens: 1024,
+              temperature: 0.7,
+              top_p: 0.9,
+              do_sample: true,
+              return_full_text: false
+            },
+            options: {
+              wait_for_model: true,
+              use_cache: false
+            },
+            stream: true
           })
         }
       );
 
       if (!response.ok) {
-        const errorData = await response.text();
-        throw new Error(`API error: ${response.status} - ${errorData}`);
+        const errorText = await response.text();
+        console.error('API Error:', response.status, errorText);
+        throw new Error(`API error: ${response.status} - ${errorText}`);
       }
 
       const reader = response.body.getReader();
       const decoder = new TextDecoder();
       let buffer = '';
+      let accumulatedText = '';
 
       while (true) {
         const { done, value } = await reader.read();
@@ -105,14 +137,31 @@ export class HuggingFaceService {
         buffer = lines.pop() || '';
 
         for (const line of lines) {
+          if (line.trim() === '') continue;
+
           if (line.startsWith('data: ') && line !== 'data: [DONE]') {
             try {
-              const data = JSON.parse(line.slice(6));
-              if (data.choices && data.choices[0] && data.choices[0].delta && data.choices[0].delta.content) {
-                onChunk(data.choices[0].delta.content);
+              const jsonData = line.slice(6);
+              if (jsonData.trim()) {
+                const data = JSON.parse(jsonData);
+
+                // Extract text from different possible response formats
+                let newText = '';
+                if (data.token && data.token.text) {
+                  newText = data.token.text;
+                } else if (data.generated_text) {
+                  newText = data.generated_text.replace(accumulatedText, '');
+                } else if (data[0] && data[0].generated_text) {
+                  newText = data[0].generated_text.replace(accumulatedText, '');
+                }
+
+                if (newText) {
+                  accumulatedText += newText;
+                  onChunk(newText);
+                }
               }
             } catch (e) {
-              // Skip invalid JSON
+              console.log('Skipping invalid JSON line:', line);
             }
           }
         }
@@ -120,7 +169,36 @@ export class HuggingFaceService {
 
       onComplete();
     } catch (error) {
+      console.error('Stream error:', error);
       onError(error.message);
     }
   }
+
+  // Format messages for inference API
+  formatMessagesForInference(messages) {
+    if (messages.length === 0) return '';
+
+    // For single message, just return the content
+    if (messages.length === 1) {
+      return messages[0].content;
+    }
+
+    // For multiple messages, format as conversation
+    let conversation = '';
+    for (const msg of messages) {
+      const role = msg.role === 'user' ? 'Human' : 'Assistant';
+      conversation += `${role}: ${msg.content}\n`;
+    }
+    conversation += 'Assistant: ';
+
+    return conversation;
+  }
+
+  // Format chat prompt
+  formatChatPrompt(messages) {
+    if (messages.length === 0) return '';
+
+    const lastMessage = messages[messages.length - 1];
+    return lastMessage.content;
+  }
 }
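
For context, a minimal sketch of how the updated streaming method might be consumed from application code. The constructor call, the modelConfig shape, and the model id below are assumptions for illustration and are not part of this commit:

// Hypothetical usage sketch; HuggingFaceService construction and the
// modelConfig shape are assumed, not taken from this diff.
const service = new HuggingFaceService();

const messages = [
  { role: 'user', content: 'Summarize streaming responses in one sentence.' }
];
const modelConfig = { endpoint: 'mistralai/Mistral-7B-Instruct-v0.2' }; // assumed model id

await service.streamChatCompletion(
  messages,
  modelConfig,
  (chunk) => process.stdout.write(chunk),        // onChunk: render each streamed piece
  () => console.log('\n[stream complete]'),      // onComplete
  (err) => console.error('Stream failed:', err)  // onError
);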