qyle commited on
Commit
392b300
·
verified ·
1 Parent(s): e43b823

deployment

Browse files
champ/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (165 Bytes). View file
 
champ/__pycache__/agent.cpython-313.pyc ADDED
Binary file (3.32 kB). View file
 
champ/__pycache__/prompts.cpython-313.pyc ADDED
Binary file (8.02 kB). View file
 
champ/__pycache__/rag.cpython-313.pyc ADDED
Binary file (1.22 kB). View file
 
champ/__pycache__/service.cpython-313.pyc ADDED
Binary file (2.86 kB). View file
 
champ/__pycache__/triage.cpython-313.pyc ADDED
Binary file (4.42 kB). View file
 
dynamodb_helper.py CHANGED
@@ -6,6 +6,9 @@ from botocore.exceptions import ClientError
6
  from datetime import datetime, timezone
7
  from uuid import uuid4
8
  from decimal import Decimal
 
 
 
9
 
10
  AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
11
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY", None)
@@ -14,25 +17,29 @@ DYNAMODB_ENDPOINT = os.getenv("DYNAMODB_ENDPOINT", None)
14
  DDB_TABLE = os.getenv("DDB_TABLE", "chatbot-conversations")
15
  USE_LOCAL_DDB = os.getenv("USE_LOCAL_DDB", "false").lower() == "true"
16
 
 
17
  def get_dynamodb_client():
18
- if USE_LOCAL_DDB: # only for local testing with DynamoDB Local
19
  return boto3.resource(
20
  "dynamodb",
21
  endpoint_url=DYNAMODB_ENDPOINT,
22
  region_name=AWS_REGION,
23
  aws_access_key_id="fake",
24
- aws_secret_access_key="fake"
 
 
 
 
 
 
 
25
  )
26
- else: # production AWS DynamoDB
27
- return boto3.resource("dynamodb",
28
- region_name=AWS_REGION,
29
- aws_access_key_id=AWS_ACCESS_KEY,
30
- aws_secret_access_key=AWS_SECRET_ACCESS_KEY
31
- )
32
 
33
  dynamodb = get_dynamodb_client()
34
  table = None
35
 
 
36
  def create_table_if_not_exists(dynamodb):
37
  global table
38
  client = dynamodb.meta.client
@@ -55,29 +62,29 @@ def create_table_if_not_exists(dynamodb):
55
  TableName=DDB_TABLE,
56
  KeySchema=[
57
  {"AttributeName": "PK", "KeyType": "HASH"},
58
- {"AttributeName": "SK", "KeyType": "RANGE"}
59
  ],
60
  AttributeDefinitions=[
61
  {"AttributeName": "PK", "AttributeType": "S"},
62
  {"AttributeName": "SK", "AttributeType": "S"},
63
  {"AttributeName": "GSI1_PK", "AttributeType": "S"},
64
- {"AttributeName": "GSI1_SK", "AttributeType": "S"}
65
  ],
66
  GlobalSecondaryIndexes=[
67
  {
68
  "IndexName": "GSI1",
69
  "KeySchema": [
70
  {"AttributeName": "GSI1_PK", "KeyType": "HASH"},
71
- {"AttributeName": "GSI1_SK", "KeyType": "RANGE"}
72
  ],
73
  "Projection": {"ProjectionType": "ALL"},
74
  "ProvisionedThroughput": {
75
  "ReadCapacityUnits": 5,
76
- "WriteCapacityUnits": 5
77
  },
78
  }
79
  ],
80
- BillingMode='PAY_PER_REQUEST'
81
  # ProvisionedThroughput={
82
  # "ReadCapacityUnits": 5,
83
  # "WriteCapacityUnits": 5
@@ -97,8 +104,10 @@ def iso_ts():
97
  # Return the current timestamp in ISO 8601 format
98
  return datetime.now(timezone.utc).isoformat()
99
 
 
100
  table = create_table_if_not_exists(dynamodb)
101
 
 
102
  def convert_floats(obj):
103
  if isinstance(obj, float):
104
  return Decimal(str(obj))
@@ -109,6 +118,7 @@ def convert_floats(obj):
109
  else:
110
  return obj
111
 
 
112
  def log_event(user_id, session_id, data):
113
  """
114
  Log conversation data to DynamoDB table.
@@ -125,12 +135,12 @@ def log_event(user_id, session_id, data):
125
  item = {
126
  "PK": f"SESSION#{session_id}",
127
  "SK": f"TS#{ts}#{uuid4().hex}",
128
- 'user_id': user_id,
129
  "GSI1_PK": f"USER#{user_id}",
130
  "GSI1_SK": f"TS#{ts}",
131
- 'session_id': session_id,
132
- 'timestamp': ts,
133
- 'data': convert_floats(data)
134
  }
135
  print(f"Logging conversation: {item}")
136
  try:
 
6
  from datetime import datetime, timezone
7
  from uuid import uuid4
8
  from decimal import Decimal
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
 
13
  AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
14
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY", None)
 
17
  DDB_TABLE = os.getenv("DDB_TABLE", "chatbot-conversations")
18
  USE_LOCAL_DDB = os.getenv("USE_LOCAL_DDB", "false").lower() == "true"
19
 
20
+
21
  def get_dynamodb_client():
22
+ if USE_LOCAL_DDB: # only for local testing with DynamoDB Local
23
  return boto3.resource(
24
  "dynamodb",
25
  endpoint_url=DYNAMODB_ENDPOINT,
26
  region_name=AWS_REGION,
27
  aws_access_key_id="fake",
28
+ aws_secret_access_key="fake",
29
+ )
30
+ else: # production AWS DynamoDB
31
+ return boto3.resource(
32
+ "dynamodb",
33
+ region_name=AWS_REGION,
34
+ aws_access_key_id=AWS_ACCESS_KEY,
35
+ aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
36
  )
37
+
 
 
 
 
 
38
 
39
  dynamodb = get_dynamodb_client()
40
  table = None
41
 
42
+
43
  def create_table_if_not_exists(dynamodb):
44
  global table
45
  client = dynamodb.meta.client
 
62
  TableName=DDB_TABLE,
63
  KeySchema=[
64
  {"AttributeName": "PK", "KeyType": "HASH"},
65
+ {"AttributeName": "SK", "KeyType": "RANGE"},
66
  ],
67
  AttributeDefinitions=[
68
  {"AttributeName": "PK", "AttributeType": "S"},
69
  {"AttributeName": "SK", "AttributeType": "S"},
70
  {"AttributeName": "GSI1_PK", "AttributeType": "S"},
71
+ {"AttributeName": "GSI1_SK", "AttributeType": "S"},
72
  ],
73
  GlobalSecondaryIndexes=[
74
  {
75
  "IndexName": "GSI1",
76
  "KeySchema": [
77
  {"AttributeName": "GSI1_PK", "KeyType": "HASH"},
78
+ {"AttributeName": "GSI1_SK", "KeyType": "RANGE"},
79
  ],
80
  "Projection": {"ProjectionType": "ALL"},
81
  "ProvisionedThroughput": {
82
  "ReadCapacityUnits": 5,
83
+ "WriteCapacityUnits": 5,
84
  },
85
  }
86
  ],
87
+ BillingMode="PAY_PER_REQUEST",
88
  # ProvisionedThroughput={
89
  # "ReadCapacityUnits": 5,
90
  # "WriteCapacityUnits": 5
 
104
  # Return the current timestamp in ISO 8601 format
105
  return datetime.now(timezone.utc).isoformat()
106
 
107
+
108
  table = create_table_if_not_exists(dynamodb)
109
 
110
+
111
  def convert_floats(obj):
112
  if isinstance(obj, float):
113
  return Decimal(str(obj))
 
118
  else:
119
  return obj
120
 
121
+
122
  def log_event(user_id, session_id, data):
123
  """
124
  Log conversation data to DynamoDB table.
 
135
  item = {
136
  "PK": f"SESSION#{session_id}",
137
  "SK": f"TS#{ts}#{uuid4().hex}",
138
+ "user_id": user_id,
139
  "GSI1_PK": f"USER#{user_id}",
140
  "GSI1_SK": f"TS#{ts}",
141
+ "session_id": session_id,
142
+ "timestamp": ts,
143
+ "data": convert_floats(data),
144
  }
145
  print(f"Logging conversation: {item}")
146
  try:
main.py CHANGED
@@ -4,14 +4,12 @@ from contextlib import asynccontextmanager
4
 
5
  from pathlib import Path
6
 
7
- from typing import List, Literal, Optional, Tuple, Dict, Any
8
- from datetime import datetime, timezone
9
 
10
  from dotenv import load_dotenv
11
- load_dotenv()
12
 
13
  from fastapi import FastAPI, Request, BackgroundTasks
14
- from fastapi.responses import HTMLResponse, JSONResponse
15
  from fastapi.staticfiles import StaticFiles
16
  from fastapi.templating import Jinja2Templates
17
 
@@ -19,7 +17,7 @@ from pydantic import BaseModel
19
  from dynamodb_helper import log_event
20
 
21
  from huggingface_hub import InferenceClient
22
- from openai import OpenAI
23
  from google import genai
24
 
25
 
@@ -28,13 +26,15 @@ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
28
  from champ.prompts import DEFAULT_SYSTEM_PROMPT
29
  from champ.service import ChampService
30
 
 
 
31
  # -------------------- Config --------------------
32
  BASE_DIR = Path(__file__).resolve().parent
33
 
34
  MODEL_MAP = {
35
  "champ": "champ-model/placeholder",
36
  "openai": "gpt-5-nano-2025-08-07",
37
- "google": "gemini-2.5-flash-lite"
38
  }
39
 
40
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_API_TOKEN")
@@ -57,13 +57,13 @@ if GEMINI_API_KEY is None:
57
  )
58
 
59
  hf_client = InferenceClient(token=HF_TOKEN)
60
- openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
61
  gemini_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
62
 
63
 
64
-
65
  # Max history messages to keep for context
66
- MAX_HISTORY = 20
 
67
 
68
  class ChatMessage(BaseModel):
69
  role: Literal["user", "assistant", "system"]
@@ -105,15 +105,22 @@ def convert_messages_langchain(messages: List[ChatMessage]):
105
  list_chatmessages.append(SystemMessage(content=m.content))
106
  return list_chatmessages
107
 
 
108
  champ = ChampService(base_dir=BASE_DIR, hf_token=HF_TOKEN)
109
 
110
- def _call_openai(model_id: str, msgs: list[dict], temperature: float) -> str:
111
- resp = openai_client.responses.create(
112
- model=model_id,
113
- input=msgs,
114
- # no temperature for GPT-5 reasoning models
 
 
115
  )
116
- return (resp.output_text or "").strip()
 
 
 
 
117
 
118
  def _call_gemini(model_id: str, msgs: list[dict], temperature: float) -> str:
119
  transcript = []
@@ -130,7 +137,12 @@ def _call_gemini(model_id: str, msgs: list[dict], temperature: float) -> str:
130
  )
131
  return (resp.text or "").strip()
132
 
133
- def _call_hf_client(model_id: str, msgs: list[dict], temperature: float,) -> str:
 
 
 
 
 
134
  resp = hf_client.chat.completions.create(
135
  model=model_id,
136
  messages=msgs,
@@ -142,7 +154,10 @@ def _call_hf_client(model_id: str, msgs: list[dict], temperature: float,) -> str
142
  except Exception:
143
  return str(resp)
144
 
145
- def call_llm(req: ChatRequest) -> Tuple[str, Dict[str, Any]]:
 
 
 
146
  if req.model_type == "champ":
147
  msgs = convert_messages_langchain(req.messages)
148
  reply, triage_meta = champ.invoke(msgs)
@@ -155,7 +170,7 @@ def call_llm(req: ChatRequest) -> Tuple[str, Dict[str, Any]]:
155
  msgs = convert_messages(req.messages)
156
 
157
  if req.model_type == "openai":
158
- return _call_openai(model_id, msgs, req.temperature), {}
159
 
160
  if req.model_type == "google":
161
  return _call_gemini(model_id, msgs, req.temperature), {}
@@ -173,6 +188,7 @@ def call_llm(req: ChatRequest) -> Tuple[str, Dict[str, Any]]:
173
  # }
174
  # conversations_collection.insert_one(record)
175
 
 
176
  # -------------------- FastAPI setup --------------------
177
  @asynccontextmanager
178
  async def lifespan(app: FastAPI):
@@ -180,6 +196,7 @@ async def lifespan(app: FastAPI):
180
  print("CHAMP RAG + agent initialized.")
181
  yield
182
 
 
183
  app = FastAPI(lifespan=lifespan)
184
  app.mount("/static", StaticFiles(directory="static"), name="static")
185
  templates = Jinja2Templates(directory="templates")
@@ -197,7 +214,34 @@ async def chat_endpoint(payload: ChatRequest, background_tasks: BackgroundTasks)
197
 
198
  try:
199
  loop = asyncio.get_running_loop()
200
- reply, triage_meta = await loop.run_in_executor(None, call_llm, payload)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  except Exception as e:
202
  background_tasks.add_task(
203
  log_event,
@@ -211,7 +255,6 @@ async def chat_endpoint(payload: ChatRequest, background_tasks: BackgroundTasks)
211
  "messages": payload.messages[-1].dict(),
212
  },
213
  )
214
- return JSONResponse({"error": str(e)}, status_code=500)
215
 
216
  background_tasks.add_task(
217
  log_event,
@@ -226,4 +269,4 @@ async def chat_endpoint(payload: ChatRequest, background_tasks: BackgroundTasks)
226
  **(triage_meta or {}),
227
  },
228
  )
229
- return {"reply": reply}
 
4
 
5
  from pathlib import Path
6
 
7
+ from typing import AsyncGenerator, List, Literal, Optional, Tuple, Dict, Any, Generator
 
8
 
9
  from dotenv import load_dotenv
 
10
 
11
  from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
13
  from fastapi.staticfiles import StaticFiles
14
  from fastapi.templating import Jinja2Templates
15
 
 
17
  from dynamodb_helper import log_event
18
 
19
  from huggingface_hub import InferenceClient
20
+ from openai import AsyncOpenAI
21
  from google import genai
22
 
23
 
 
26
  from champ.prompts import DEFAULT_SYSTEM_PROMPT
27
  from champ.service import ChampService
28
 
29
+ load_dotenv()
30
+
31
  # -------------------- Config --------------------
32
  BASE_DIR = Path(__file__).resolve().parent
33
 
34
  MODEL_MAP = {
35
  "champ": "champ-model/placeholder",
36
  "openai": "gpt-5-nano-2025-08-07",
37
+ "google": "gemini-2.5-flash-lite",
38
  }
39
 
40
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HF_API_TOKEN")
 
57
  )
58
 
59
  hf_client = InferenceClient(token=HF_TOKEN)
60
+ openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
61
  gemini_client = genai.Client(api_key=GEMINI_API_KEY) if GEMINI_API_KEY else None
62
 
63
 
 
64
  # Max history messages to keep for context
65
+ MAX_HISTORY = 20
66
+
67
 
68
  class ChatMessage(BaseModel):
69
  role: Literal["user", "assistant", "system"]
 
105
  list_chatmessages.append(SystemMessage(content=m.content))
106
  return list_chatmessages
107
 
108
+
109
  champ = ChampService(base_dir=BASE_DIR, hf_token=HF_TOKEN)
110
 
111
+
112
+ async def _call_openai(
113
+ model_id: str, msgs: list[dict], temperature: float
114
+ ) -> AsyncGenerator[str, None]:
115
+ # We are streaming the output because the model answers tend to be very long and slow to generate
116
+ stream = await openai_client.responses.create(
117
+ model=model_id, input=msgs, stream=True
118
  )
119
+
120
+ async for chunk in stream:
121
+ if chunk.type == "response.output_text.delta":
122
+ yield chunk.delta
123
+
124
 
125
  def _call_gemini(model_id: str, msgs: list[dict], temperature: float) -> str:
126
  transcript = []
 
137
  )
138
  return (resp.text or "").strip()
139
 
140
+
141
+ def _call_hf_client(
142
+ model_id: str,
143
+ msgs: list[dict],
144
+ temperature: float,
145
+ ) -> str:
146
  resp = hf_client.chat.completions.create(
147
  model=model_id,
148
  messages=msgs,
 
154
  except Exception:
155
  return str(resp)
156
 
157
+
158
+ def call_llm(
159
+ req: ChatRequest,
160
+ ) -> AsyncGenerator[str, None] | Tuple[str, Dict[str, Any]]:
161
  if req.model_type == "champ":
162
  msgs = convert_messages_langchain(req.messages)
163
  reply, triage_meta = champ.invoke(msgs)
 
170
  msgs = convert_messages(req.messages)
171
 
172
  if req.model_type == "openai":
173
+ return _call_openai(model_id, msgs, req.temperature)
174
 
175
  if req.model_type == "google":
176
  return _call_gemini(model_id, msgs, req.temperature), {}
 
188
  # }
189
  # conversations_collection.insert_one(record)
190
 
191
+
192
  # -------------------- FastAPI setup --------------------
193
  @asynccontextmanager
194
  async def lifespan(app: FastAPI):
 
196
  print("CHAMP RAG + agent initialized.")
197
  yield
198
 
199
+
200
  app = FastAPI(lifespan=lifespan)
201
  app.mount("/static", StaticFiles(directory="static"), name="static")
202
  templates = Jinja2Templates(directory="templates")
 
214
 
215
  try:
216
  loop = asyncio.get_running_loop()
217
+ result = await loop.run_in_executor(None, call_llm, payload)
218
+
219
+ if isinstance(result, AsyncGenerator):
220
+
221
+ async def logging_wrapper():
222
+ reply = ""
223
+ async for token in result:
224
+ reply += token
225
+ yield token
226
+
227
+ background_tasks.add_task(
228
+ log_event,
229
+ user_id=payload.user_id,
230
+ session_id=payload.session_id,
231
+ data={
232
+ "model_type": payload.model_type,
233
+ "consent": payload.consent,
234
+ "temperature": payload.temperature,
235
+ "messages": payload.messages[-1].dict(),
236
+ "reply": reply,
237
+ "triage_meta": {},
238
+ },
239
+ )
240
+
241
+ return StreamingResponse(logging_wrapper(), media_type="text/event-stream")
242
+
243
+ reply, triage_meta = result
244
+
245
  except Exception as e:
246
  background_tasks.add_task(
247
  log_event,
 
255
  "messages": payload.messages[-1].dict(),
256
  },
257
  )
 
258
 
259
  background_tasks.add_task(
260
  log_event,
 
269
  **(triage_meta or {}),
270
  },
271
  )
272
+ return {"reply": reply}
requirements.txt CHANGED
@@ -124,3 +124,4 @@ websockets==15.0.1
124
  xxhash==3.6.0
125
  yarl==1.22.0
126
  zstandard==0.25.0
 
 
124
  xxhash==3.6.0
125
  yarl==1.22.0
126
  zstandard==0.25.0
127
+ pytz==2025.2
static/app.js CHANGED
@@ -15,14 +15,21 @@ const consentCheckbox = document.getElementById('consentCheckbox');
15
  const consentBtn = document.getElementById('consentBtn');
16
 
17
  // Local in-browser chat history
18
- let messages = [];
 
 
 
 
 
 
19
  let consentGranted = false;
20
  let sessionId = 'session-' + crypto.randomUUID(); // Unique session ID, generated once per page load
21
  document.body.classList.add('no-scroll');
22
 
23
  function renderMessages() {
24
  chatWindow.innerHTML = '';
25
- messages.forEach((m) => {
 
26
  const bubble = document.createElement('div');
27
  bubble.classList.add(
28
  'msg-bubble',
@@ -58,7 +65,8 @@ async function sendMessage() {
58
  if (!text) return;
59
 
60
  // Add user message locally
61
- messages.push({ role: 'user', content: text });
 
62
  renderMessages();
63
  userInput.value = '';
64
 
@@ -68,12 +76,11 @@ async function sendMessage() {
68
  const temperature = parseFloat(tempSlider.value);
69
  // const maxTokens = parseInt(maxTokensSlider.value, 10);
70
  // const systemPrompt = systemPresetSelect.value;
71
- const modelType = systemPresetSelect.value;
72
 
73
  const payload = {
74
  user_id: getMachineId(),
75
  session_id: sessionId,
76
- messages: messages.map((m) => ({ role: m.role, content: m.content })),
77
  temperature,
78
  // max_new_tokens: maxTokens,
79
  model_type: modelType,
@@ -87,17 +94,39 @@ async function sendMessage() {
87
  body: JSON.stringify(payload),
88
  });
89
 
90
- const data = await res.json();
91
-
92
  if (!res.ok) {
93
  statusEl.textContent = data.error || 'Error from server.';
94
  statusEl.className = 'status status-error';
95
  return;
96
  }
97
 
98
- const reply = data.reply || '(No reply)';
99
- messages.push({ role: 'assistant', content: reply });
100
- renderMessages();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  statusEl.textContent = 'Ready';
103
  statusEl.className = 'status status-ok';
@@ -113,8 +142,16 @@ function resetSession() {
113
  }
114
 
115
  function clearConversation() {
 
 
116
  resetSession();
117
- messages = [];
 
 
 
 
 
 
118
  renderMessages();
119
  statusEl.textContent = 'Conversation cleared. Start a new chat!';
120
  statusEl.className = 'status status-ok';
@@ -147,15 +184,19 @@ userInput.addEventListener('keydown', (e) => {
147
  });
148
 
149
  tempSlider.addEventListener('input', () => {
150
- if (!tempSlider.disabled) updateSlidersUI();
 
 
 
 
151
  });
152
  // maxTokensSlider.addEventListener("input", updateSlidersUI);
153
  clearBtn.addEventListener('click', clearConversation);
154
 
155
  systemPresetSelect.addEventListener('change', () => {
156
  updateTempControlForModel(); // 👈 add this
157
- clearConversation();
158
- statusEl.textContent = 'Model changed. History cleared.';
159
  statusEl.className = 'status status-ok';
160
  });
161
 
@@ -179,6 +220,7 @@ function updateTempControlForModel() {
179
  // Enable slider for other models
180
  tempSlider.disabled = false;
181
  tempSlider.classList.remove('disabled');
 
182
  updateSlidersUI(); // refresh displayed value
183
  }
184
  }
 
15
  const consentBtn = document.getElementById('consentBtn');
16
 
17
  // Local in-browser chat history
18
+ // We store for each model its chat history.
19
+ // We store the temperature of the google model as it can change.
20
+ const modelChats = {};
21
+ modelChats["champ"] = {"messages": []}
22
+ modelChats["openai"] = {"messages": []}
23
+ modelChats["google"] = {"messages": [], "temperature": 0.2}
24
+
25
  let consentGranted = false;
26
  let sessionId = 'session-' + crypto.randomUUID(); // Unique session ID, generated once per page load
27
  document.body.classList.add('no-scroll');
28
 
29
  function renderMessages() {
30
  chatWindow.innerHTML = '';
31
+ const modelType = systemPresetSelect.value;
32
+ modelChats[modelType]["messages"].forEach((m) => {
33
  const bubble = document.createElement('div');
34
  bubble.classList.add(
35
  'msg-bubble',
 
65
  if (!text) return;
66
 
67
  // Add user message locally
68
+ const modelType = systemPresetSelect.value;
69
+ modelChats[modelType]["messages"].push({ role: 'user', content: text });
70
  renderMessages();
71
  userInput.value = '';
72
 
 
76
  const temperature = parseFloat(tempSlider.value);
77
  // const maxTokens = parseInt(maxTokensSlider.value, 10);
78
  // const systemPrompt = systemPresetSelect.value;
 
79
 
80
  const payload = {
81
  user_id: getMachineId(),
82
  session_id: sessionId,
83
+ messages: modelChats[modelType]["messages"].map((m) => ({ role: m.role, content: m.content })),
84
  temperature,
85
  // max_new_tokens: maxTokens,
86
  model_type: modelType,
 
94
  body: JSON.stringify(payload),
95
  });
96
 
 
 
97
  if (!res.ok) {
98
  statusEl.textContent = data.error || 'Error from server.';
99
  statusEl.className = 'status status-error';
100
  return;
101
  }
102
 
103
+ const contentType = res.headers.get('content-type');
104
+
105
+ if (contentType && contentType.includes('application/json')) {
106
+ // Batch response
107
+ const data = await res.json();
108
+
109
+ const reply = data.reply || '(No reply)';
110
+ modelChats[modelType]["messages"].push({ role: 'assistant', content: reply });
111
+ renderMessages();
112
+ } else {
113
+ // Streaming response
114
+ const assistantMessage = { role: 'assistant', content: '' };
115
+ modelChats[modelType]["messages"].push(assistantMessage);
116
+
117
+ const reader = res.body.getReader();
118
+ const decoder = new TextDecoder();
119
+ let done = false;
120
+
121
+ while (!done) {
122
+ const { value, done: readerDone } = await reader.read();
123
+ done = readerDone;
124
+ const chunk = decoder.decode(value, { stream: true });
125
+ assistantMessage.content += chunk;
126
+ renderMessages();
127
+ }
128
+ }
129
+
130
 
131
  statusEl.textContent = 'Ready';
132
  statusEl.className = 'status status-ok';
 
142
  }
143
 
144
  function clearConversation() {
145
+ const modelType = systemPresetSelect.value;
146
+
147
  resetSession();
148
+ modelChats[modelType]["messages"] = [];
149
+ // If the model is google, we also have to clear the temperature
150
+ if (modelType === "google") {
151
+ modelChats["google"]["temperature"] = 0.2;
152
+ tempSlider.value = 0.2
153
+ updateSlidersUI();
154
+ }
155
  renderMessages();
156
  statusEl.textContent = 'Conversation cleared. Start a new chat!';
157
  statusEl.className = 'status status-ok';
 
184
  });
185
 
186
  tempSlider.addEventListener('input', () => {
187
+ if (!tempSlider.disabled) {
188
+ updateSlidersUI();
189
+ const modelType = systemPresetSelect.value;
190
+ modelChats[modelType]["temperature"] = tempSlider.value;
191
+ }
192
  });
193
  // maxTokensSlider.addEventListener("input", updateSlidersUI);
194
  clearBtn.addEventListener('click', clearConversation);
195
 
196
  systemPresetSelect.addEventListener('change', () => {
197
  updateTempControlForModel(); // 👈 add this
198
+ renderMessages();
199
+ statusEl.textContent = 'Model changed.';
200
  statusEl.className = 'status status-ok';
201
  });
202
 
 
220
  // Enable slider for other models
221
  tempSlider.disabled = false;
222
  tempSlider.classList.remove('disabled');
223
+ tempSlider.value = modelChats[model]["temperature"];
224
  updateSlidersUI(); // refresh displayed value
225
  }
226
  }