RamizXhah committed on
Commit
3345e2e
·
verified ·
1 Parent(s): 0897989

Update chat.py

Browse files
Files changed (1) hide show
  1. chat.py +14 -5
chat.py CHANGED
@@ -10,17 +10,22 @@ _tokenizer = None
10
  _model = None
11
 
12
def _load_chatbot_resources():
    """Lazily load and cache the chatbot tokenizer and model.

    Both objects are created at most once per process and stored in the
    module-level ``_tokenizer`` / ``_model`` globals.

    Returns:
        tuple: ``(tokenizer, model)`` loaded from ``MODEL_NAME``.
    """
    global _tokenizer, _model
    if _tokenizer is None or _model is None:
        # FIX: decoder-only chat models (e.g. DialoGPT) must be left-padded;
        # the default right padding corrupts generation and emits a
        # transformers warning.
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
        _model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return _tokenizer, _model
18
 
19
  class ChatSessionManager:
 
20
  def __init__(self) -> None:
21
  self._sessions: Dict[str, Dict[str, object]] = {}
22
 
23
  def _ensure_session(self, session_id: str | None) -> str:
 
24
  if not session_id or session_id not in self._sessions:
25
  session_id = uuid.uuid4().hex
26
  self._sessions[session_id] = {
@@ -30,17 +35,19 @@ class ChatSessionManager:
30
  return session_id
31
 
32
  def _generate_reply(self, history_tokens, message):
 
33
  tokenizer, model = _load_chatbot_resources()
34
- # Encode user message
 
35
  user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
36
 
37
- # Concatenate history and current input
38
  if history_tokens is not None:
39
  bot_input_ids = torch.cat([history_tokens, user_input_ids], dim=-1)
40
  else:
41
  bot_input_ids = user_input_ids
42
 
43
- # Generate response
44
  generated_ids = model.generate(
45
  bot_input_ids,
46
  max_length=1024,
@@ -51,13 +58,14 @@ class ChatSessionManager:
51
  temperature=0.8
52
  )
53
 
54
- # Decode the reply
55
  reply_ids = generated_ids[:, bot_input_ids.shape[-1]:]
56
  reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
57
 
58
  return generated_ids, reply_text or "I am still thinking about that."
59
 
60
  def handle_message(self, session_id: str | None, message: str) -> Tuple[str, str, List[Dict[str, str]]]:
 
61
  session_id = self._ensure_session(session_id)
62
  state = self._sessions[session_id]
63
  transcript: List[Dict[str, str]] = state["transcript"] # type: ignore[assignment]
@@ -77,6 +85,7 @@ class ChatSessionManager:
77
  return reply, session_id, list(transcript)
78
 
79
  def get_history(self, session_id: str) -> List[Dict[str, str]]:
 
80
  state = self._sessions.get(session_id, {"transcript": []})
81
  return list(state["transcript"]) # type: ignore[index]
82
 
 
10
  _model = None
11
 
12
def _load_chatbot_resources():
    """Load the tokenizer/model pair once and reuse the cached instances."""
    global _tokenizer, _model
    already_loaded = _tokenizer is not None and _model is not None
    if not already_loaded:
        # Left padding is required for decoder-only models (like DialoGPT)
        # so that generated tokens are correct and the warning is suppressed.
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
        _model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return _tokenizer, _model
21
 
22
  class ChatSessionManager:
23
+ """Manages chat sessions, history tokens, and conversation transcripts."""
24
  def __init__(self) -> None:
25
  self._sessions: Dict[str, Dict[str, object]] = {}
26
 
27
  def _ensure_session(self, session_id: str | None) -> str:
28
+ """Ensures a valid session ID exists, creating a new one if necessary."""
29
  if not session_id or session_id not in self._sessions:
30
  session_id = uuid.uuid4().hex
31
  self._sessions[session_id] = {
 
35
  return session_id
36
 
37
  def _generate_reply(self, history_tokens, message):
38
+ """Encodes input, generates a response using the model, and decodes the result."""
39
  tokenizer, model = _load_chatbot_resources()
40
+
41
+ # 1. Encode user message
42
  user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
43
 
44
+ # 2. Concatenate history and current input
45
  if history_tokens is not None:
46
  bot_input_ids = torch.cat([history_tokens, user_input_ids], dim=-1)
47
  else:
48
  bot_input_ids = user_input_ids
49
 
50
+ # 3. Generate response
51
  generated_ids = model.generate(
52
  bot_input_ids,
53
  max_length=1024,
 
58
  temperature=0.8
59
  )
60
 
61
+ # 4. Decode the reply (only the new part)
62
  reply_ids = generated_ids[:, bot_input_ids.shape[-1]:]
63
  reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
64
 
65
  return generated_ids, reply_text or "I am still thinking about that."
66
 
67
  def handle_message(self, session_id: str | None, message: str) -> Tuple[str, str, List[Dict[str, str]]]:
68
+ """Processes an incoming message, generates a reply, and updates session history."""
69
  session_id = self._ensure_session(session_id)
70
  state = self._sessions[session_id]
71
  transcript: List[Dict[str, str]] = state["transcript"] # type: ignore[assignment]
 
85
  return reply, session_id, list(transcript)
86
 
87
  def get_history(self, session_id: str) -> List[Dict[str, str]]:
88
+ """Retrieves the full transcript for a given session ID."""
89
  state = self._sessions.get(session_id, {"transcript": []})
90
  return list(state["transcript"]) # type: ignore[index]
91