Spaces:
Sleeping
Sleeping
Update chat.py
Browse files
chat.py
CHANGED
|
@@ -10,17 +10,22 @@ _tokenizer = None
|
|
| 10 |
_model = None
|
| 11 |
|
| 12 |
def _load_chatbot_resources():
|
|
|
|
| 13 |
global _tokenizer, _model
|
| 14 |
if _tokenizer is None or _model is None:
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
|
| 17 |
return _tokenizer, _model
|
| 18 |
|
| 19 |
class ChatSessionManager:
|
|
|
|
| 20 |
def __init__(self) -> None:
|
| 21 |
self._sessions: Dict[str, Dict[str, object]] = {}
|
| 22 |
|
| 23 |
def _ensure_session(self, session_id: str | None) -> str:
|
|
|
|
| 24 |
if not session_id or session_id not in self._sessions:
|
| 25 |
session_id = uuid.uuid4().hex
|
| 26 |
self._sessions[session_id] = {
|
|
@@ -30,17 +35,19 @@ class ChatSessionManager:
|
|
| 30 |
return session_id
|
| 31 |
|
| 32 |
def _generate_reply(self, history_tokens, message):
|
|
|
|
| 33 |
tokenizer, model = _load_chatbot_resources()
|
| 34 |
-
|
|
|
|
| 35 |
user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
|
| 36 |
|
| 37 |
-
# Concatenate history and current input
|
| 38 |
if history_tokens is not None:
|
| 39 |
bot_input_ids = torch.cat([history_tokens, user_input_ids], dim=-1)
|
| 40 |
else:
|
| 41 |
bot_input_ids = user_input_ids
|
| 42 |
|
| 43 |
-
# Generate response
|
| 44 |
generated_ids = model.generate(
|
| 45 |
bot_input_ids,
|
| 46 |
max_length=1024,
|
|
@@ -51,13 +58,14 @@ class ChatSessionManager:
|
|
| 51 |
temperature=0.8
|
| 52 |
)
|
| 53 |
|
| 54 |
-
# Decode the reply
|
| 55 |
reply_ids = generated_ids[:, bot_input_ids.shape[-1]:]
|
| 56 |
reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
|
| 57 |
|
| 58 |
return generated_ids, reply_text or "I am still thinking about that."
|
| 59 |
|
| 60 |
def handle_message(self, session_id: str | None, message: str) -> Tuple[str, str, List[Dict[str, str]]]:
|
|
|
|
| 61 |
session_id = self._ensure_session(session_id)
|
| 62 |
state = self._sessions[session_id]
|
| 63 |
transcript: List[Dict[str, str]] = state["transcript"] # type: ignore[assignment]
|
|
@@ -77,6 +85,7 @@ class ChatSessionManager:
|
|
| 77 |
return reply, session_id, list(transcript)
|
| 78 |
|
| 79 |
def get_history(self, session_id: str) -> List[Dict[str, str]]:
|
|
|
|
| 80 |
state = self._sessions.get(session_id, {"transcript": []})
|
| 81 |
return list(state["transcript"]) # type: ignore[index]
|
| 82 |
|
|
|
|
| 10 |
_model = None
|
| 11 |
|
| 12 |
def _load_chatbot_resources():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module-level ``_tokenizer`` / ``_model``
    globals, so the heavyweight ``from_pretrained`` calls run at most
    once per process; every later call just returns the cached objects.
    """
    global _tokenizer, _model
    if _model is None or _tokenizer is None:
        # padding_side='left' keeps prompts right-aligned, which decoder-only
        # models (e.g. DialoGPT) require for correct generation; it also
        # suppresses the corresponding transformers warning.
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side='left')
        _model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return _tokenizer, _model
|
| 21 |
|
| 22 |
class ChatSessionManager:
|
| 23 |
+
"""Manages chat sessions, history tokens, and conversation transcripts."""
|
| 24 |
def __init__(self) -> None:
    """Create a manager that starts with no active sessions."""
    # Maps session_id -> per-session state dict. The state holds at least
    # a "transcript" list (see get_history); NOTE(review): other keys are
    # stored by code outside this view — confirm before relying on them.
    self._sessions: Dict[str, Dict[str, object]] = {}
|
| 26 |
|
| 27 |
def _ensure_session(self, session_id: str | None) -> str:
|
| 28 |
+
"""Ensures a valid session ID exists, creating a new one if necessary."""
|
| 29 |
if not session_id or session_id not in self._sessions:
|
| 30 |
session_id = uuid.uuid4().hex
|
| 31 |
self._sessions[session_id] = {
|
|
|
|
| 35 |
return session_id
|
| 36 |
|
| 37 |
def _generate_reply(self, history_tokens, message):
|
| 38 |
+
"""Encodes input, generates a response using the model, and decodes the result."""
|
| 39 |
tokenizer, model = _load_chatbot_resources()
|
| 40 |
+
|
| 41 |
+
# 1. Encode user message
|
| 42 |
user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
|
| 43 |
|
| 44 |
+
# 2. Concatenate history and current input
|
| 45 |
if history_tokens is not None:
|
| 46 |
bot_input_ids = torch.cat([history_tokens, user_input_ids], dim=-1)
|
| 47 |
else:
|
| 48 |
bot_input_ids = user_input_ids
|
| 49 |
|
| 50 |
+
# 3. Generate response
|
| 51 |
generated_ids = model.generate(
|
| 52 |
bot_input_ids,
|
| 53 |
max_length=1024,
|
|
|
|
| 58 |
temperature=0.8
|
| 59 |
)
|
| 60 |
|
| 61 |
+
# 4. Decode the reply (only the new part)
|
| 62 |
reply_ids = generated_ids[:, bot_input_ids.shape[-1]:]
|
| 63 |
reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
|
| 64 |
|
| 65 |
return generated_ids, reply_text or "I am still thinking about that."
|
| 66 |
|
| 67 |
def handle_message(self, session_id: str | None, message: str) -> Tuple[str, str, List[Dict[str, str]]]:
|
| 68 |
+
"""Processes an incoming message, generates a reply, and updates session history."""
|
| 69 |
session_id = self._ensure_session(session_id)
|
| 70 |
state = self._sessions[session_id]
|
| 71 |
transcript: List[Dict[str, str]] = state["transcript"] # type: ignore[assignment]
|
|
|
|
| 85 |
return reply, session_id, list(transcript)
|
| 86 |
|
| 87 |
def get_history(self, session_id: str) -> List[Dict[str, str]]:
    """Return a copy of the transcript for *session_id*.

    Unknown session IDs yield an empty list instead of raising, and the
    returned list is a shallow copy so callers cannot mutate the stored
    transcript in place.
    """
    empty_state: Dict[str, object] = {"transcript": []}
    state = self._sessions.get(session_id, empty_state)
    return list(state["transcript"])  # type: ignore[index]
|
| 91 |
|