"""
Chatbot module using HuggingFace Transformers.
Uses gpt-oss-20b model with AutoModelForCausalLM, AutoTokenizer, and chat templates.
"""
import os
from threading import Thread
from typing import Generator

import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Load environment variables (e.g. HF_TOKEN) from a local .env file, if present.
load_dotenv()
# Model configuration
MODEL_ID = "openai/gpt-oss-20b"
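# For a quick local smoke test one could point MODEL_ID at a much smaller chat
# model instead (illustrative alternative, not used by the deployed app), e.g.:
# MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"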
class Chatbot:
"""
A chatbot class that uses HuggingFace Transformers
with AutoModelForCausalLM and AutoTokenizer for text generation.
"""
def __init__(self, model_id: str = MODEL_ID):
"""Initialize the chatbot with the specified model."""
self.model_id = model_id
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
model_id,
# token=os.getenv("HF_TOKEN"),
trust_remote_code=True
)
# Set pad token if not set
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model with appropriate settings
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
# token=os.getenv("HF_TOKEN"),
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
device_map="auto" if torch.cuda.is_available() else None,
trust_remote_code=True,
low_cpu_mem_usage=True
)
if not torch.cuda.is_available():
self.model = self.model.to(self.device)
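        # Inference only: eval() disables dropout and other training-time behavior.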
self.model.eval()
self.system_prompt = (
"You are a helpful, friendly AI assistant. "
"You provide clear, accurate, and concise responses. "
"You can help with various tasks including coding, analysis, and general questions."
)
def _format_messages(self, message: str, history: list) -> list:
"""
Format the conversation history into the chat template format.
Args:
message: The current user message
history: List of [user_msg, assistant_msg] pairs
Returns:
List of message dictionaries for the chat template
"""
messages = [{"role": "system", "content": self.system_prompt}]
for user_msg, assistant_msg in history:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
return messages
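    # Illustrative shape of the list returned by
    # _format_messages("Hi", [["Hello", "Hey there!"]]):
    #   [{"role": "system", "content": <system prompt>},
    #    {"role": "user", "content": "Hello"},
    #    {"role": "assistant", "content": "Hey there!"},
    #    {"role": "user", "content": "Hi"}]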
def chat(self, message: str, history: list) -> str:
"""
Generate a response to the user's message using transformers.
Args:
message: The user's input message
history: Conversation history as list of [user, assistant] pairs
Returns:
The assistant's response
"""
messages = self._format_messages(message, history)
try:
# Apply chat template
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Tokenize input
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=4096
).to(self.device)
# Generate response
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=1024,
temperature=0.7,
top_p=0.95,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id
)
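            # Note: gpt-oss models use the harmony chat format and may emit an
            # "analysis" channel before the final answer; if analysis text
            # leaks into the decoded output, extra parsing of the channel
            # markers would be needed (out of scope for this sketch).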
# Decode response (only the new tokens)
response = self.tokenizer.decode(
outputs[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True
)
return response.strip()
except Exception as e:
return f"Error generating response: {str(e)}"
def chat_stream(self, message: str, history: list) -> Generator[str, None, None]:
"""
Stream a response to the user's message for better UX.
Args:
message: The user's input message
history: Conversation history
Yields:
Chunks of the response as they are generated
"""
messages = self._format_messages(message, history)
try:
# Apply chat template
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
# Tokenize input
inputs = self.tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=4096
).to(self.device)
# Create streamer
streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True
)
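            # TextIteratorStreamer exposes decoded text through a queue:
            # iterating over it below yields chunks as model.generate()
            # produces tokens on the background thread.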
# Generation kwargs
generation_kwargs = dict(
**inputs,
max_new_tokens=1024,
temperature=0.7,
top_p=0.95,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
streamer=streamer
)
# Run generation in a separate thread
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
            # Stream the response; Gradio's ChatInterface expects the full
            # text accumulated so far on each yield, not just the new delta.
response = ""
for new_text in streamer:
response += new_text
yield response.strip()
thread.join()
except Exception as e:
yield f"Error generating response: {str(e)}"
# Create a default chatbot instance (lazy loading)
_chatbot = None
def get_chatbot() -> Chatbot:
"""Get or create the chatbot instance."""
global _chatbot
if _chatbot is None:
_chatbot = Chatbot()
return _chatbot
def chat_fn(message: str, history: list) -> str:
"""
Function to be used with Gradio ChatInterface.
Args:
message: User's input message
history: Conversation history
Returns:
Assistant's response
"""
return get_chatbot().chat(message, history)
def chat_stream_fn(message: str, history: list) -> Generator[str, None, None]:
"""
Streaming function for Gradio ChatInterface.
Args:
message: User's input message
history: Conversation history
Yields:
Response chunks
"""
yield from get_chatbot().chat_stream(message, history)
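

# Minimal command-line smoke test (a sketch: it loads the full model, so it
# assumes enough RAM/VRAM for gpt-oss-20b or whatever MODEL_ID points at).
if __name__ == "__main__":
    print(chat_fn("Hello! What can you do?", []))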