# Rephrasia / chat.py
# Uploaded by RamizXhah ("Upload 13 files", commit 58373e3, verified).
"""DialoGPT-powered chatbot with session history."""
from __future__ import annotations
import uuid
from typing import Dict, List, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face model id for the conversational model.
MODEL_NAME = "microsoft/DialoGPT-medium"
# Lazily-initialized module-level singletons; populated on first use by
# _load_chatbot_resources() so importing this module stays cheap.
_tokenizer = None
_model = None
def _load_chatbot_resources():
    """Lazily load and cache the DialoGPT tokenizer and model.

    Returns the cached ``(tokenizer, model)`` pair, downloading/loading
    both on the first call only.
    """
    global _tokenizer, _model
    if _model is None or _tokenizer is None:
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return _tokenizer, _model
class ChatSessionManager:
    """Track per-session DialoGPT conversation state.

    Each session keeps two things: the raw generated token history (fed back
    into the model for context) and a human-readable transcript of
    role/message dicts.
    """

    def __init__(self) -> None:
        # session_id -> {"history_tokens": Tensor | None, "transcript": list[dict]}
        self._sessions: Dict[str, Dict[str, object]] = {}

    def _ensure_session(self, session_id: str | None) -> str:
        """Return a valid session id, creating a fresh session when needed."""
        if not session_id or session_id not in self._sessions:
            session_id = uuid.uuid4().hex
            # BUG FIX: the original dict literal was missing its closing
            # brace, which made this whole module a SyntaxError.
            self._sessions[session_id] = {
                "history_tokens": None,
                "transcript": [],
            }
        return session_id

    def _generate_reply(self, history_tokens, message):
        """Run one generation step.

        Returns ``(updated_history_tokens, reply_text)``; the updated history
        includes both the new user turn and the model's reply.
        """
        tokenizer, model = _load_chatbot_resources()
        user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
        if history_tokens is not None:
            bot_input_ids = torch.cat([history_tokens, user_input_ids], dim=-1)
        else:
            bot_input_ids = user_input_ids
        generated_ids = model.generate(
            bot_input_ids,
            max_length=1024,
            pad_token_id=tokenizer.eos_token_id,  # DialoGPT has no pad token; reuse EOS
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.8,
        )
        # Everything past the prompt tokens is the model's reply.
        reply_ids = generated_ids[:, bot_input_ids.shape[-1]:]
        reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
        return generated_ids, reply_text or "I am still thinking about that."

    def handle_message(self, session_id: str | None, message: str) -> Tuple[str, str, List[Dict[str, str]]]:
        """Generate a reply for *message* within the given session.

        Returns ``(reply, session_id, transcript_copy)``. Generation failures
        are downgraded to an apology reply, with the error text recorded in
        the transcript as a "system" entry, so callers never see an exception.
        """
        session_id = self._ensure_session(session_id)
        state = self._sessions[session_id]
        transcript: List[Dict[str, str]] = state["transcript"]  # type: ignore[assignment]
        try:
            updated_tokens, reply = self._generate_reply(state["history_tokens"], message)
            state["history_tokens"] = updated_tokens
        except Exception as exc:  # boundary: surface the failure in the transcript
            reply = "Sorry, I ran into an issue generating a response."
            transcript.append({"role": "system", "message": str(exc)})
        transcript.append({"role": "user", "message": message})
        transcript.append({"role": "assistant", "message": reply})
        return reply, session_id, list(transcript)

    def get_history(self, session_id: str) -> List[Dict[str, str]]:
        """Return a copy of the transcript for *session_id* (empty if unknown)."""
        state = self._sessions.get(session_id, {"transcript": []})
        return list(state["transcript"])  # type: ignore[index]
# Module-level singleton shared by the application.
chat_manager = ChatSessionManager()
"""DialoGPT-powered chatbot with session history."""
from __future__ import annotations
import uuid
from typing import Dict, List, Tuple
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face model id for the conversational model.
MODEL_NAME = "microsoft/DialoGPT-medium"
# Lazily-initialized module-level singletons; populated on first use by
# _load_chatbot_resources() so importing this module stays cheap.
_tokenizer = None
_model = None
def _load_chatbot_resources():
    """Return the cached (tokenizer, model) pair, loading both on first use."""
    global _tokenizer, _model
    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model
    _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    _model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return _tokenizer, _model
class ChatSessionManager:
    """Track per-session DialoGPT conversation state.

    Each session keeps the raw generated token history (fed back into the
    model for context) plus a human-readable transcript of role/message dicts.
    """

    # generate() is capped at max_length=1024 total tokens. Keep the rolling
    # history comfortably below that so there is always room for the new user
    # turn and the reply; without this cap the history grew without bound and
    # generation broke once the prompt filled the whole 1024-token budget.
    _MAX_HISTORY_TOKENS = 768

    def __init__(self) -> None:
        # session_id -> {"history_tokens": Tensor | None, "transcript": list[dict]}
        self._sessions: Dict[str, Dict[str, object]] = {}

    def _ensure_session(self, session_id: str | None) -> str:
        """Return a valid session id, creating a fresh session when needed."""
        if not session_id or session_id not in self._sessions:
            session_id = uuid.uuid4().hex
            self._sessions[session_id] = {
                "history_tokens": None,
                "transcript": [],
            }
        return session_id

    def _generate_reply(self, history_tokens, message):
        """Run one generation step.

        Returns ``(updated_history_tokens, reply_text)``; the updated history
        includes both the new user turn and the model's reply.
        """
        tokenizer, model = _load_chatbot_resources()
        user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
        if history_tokens is not None:
            # Keep only the most recent context tokens (see _MAX_HISTORY_TOKENS).
            trimmed = history_tokens[:, -self._MAX_HISTORY_TOKENS:]
            bot_input_ids = torch.cat([trimmed, user_input_ids], dim=-1)
        else:
            bot_input_ids = user_input_ids
        with torch.no_grad():  # pure inference: skip autograd bookkeeping
            generated_ids = model.generate(
                bot_input_ids,
                max_length=1024,
                pad_token_id=tokenizer.eos_token_id,  # DialoGPT has no pad token; reuse EOS
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.8,
            )
        # Everything past the prompt tokens is the model's reply.
        reply_ids = generated_ids[:, bot_input_ids.shape[-1]:]
        reply_text = tokenizer.decode(reply_ids[0], skip_special_tokens=True)
        return generated_ids, reply_text or "I am still thinking about that."

    def handle_message(self, session_id: str | None, message: str) -> Tuple[str, str, List[Dict[str, str]]]:
        """Generate a reply for *message* within the given session.

        Returns ``(reply, session_id, transcript_copy)``. Generation failures
        are downgraded to an apology reply, with the error text recorded in
        the transcript as a "system" entry, so callers never see an exception.
        """
        session_id = self._ensure_session(session_id)
        state = self._sessions[session_id]
        transcript: List[Dict[str, str]] = state["transcript"]  # type: ignore[assignment]
        try:
            updated_tokens, reply = self._generate_reply(state["history_tokens"], message)
            state["history_tokens"] = updated_tokens
        except Exception as exc:  # boundary: surface the failure in the transcript
            reply = "Sorry, I ran into an issue generating a response."
            transcript.append({"role": "system", "message": str(exc)})
        transcript.append({"role": "user", "message": message})
        transcript.append({"role": "assistant", "message": reply})
        return reply, session_id, list(transcript)

    def get_history(self, session_id: str) -> List[Dict[str, str]]:
        """Return a copy of the transcript for *session_id* (empty if unknown)."""
        state = self._sessions.get(session_id, {"transcript": []})
        return list(state["transcript"])  # type: ignore[index]
# Module-level singleton shared by the application.
chat_manager = ChatSessionManager()