""" Chatbot module using HuggingFace Transformers. Uses gpt-oss-20b model with AutoModelForCausalLM, AutoTokenizer, and chat templates. """ import os import torch from typing import Generator from dotenv import load_dotenv from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from threading import Thread load_dotenv() # Model configuration MODEL_ID = "openai/gpt-oss-20b" class Chatbot: """ A chatbot class that uses HuggingFace Transformers with AutoModelForCausalLM and AutoTokenizer for text generation. """ def __init__(self, model_id: str = MODEL_ID): """Initialize the chatbot with the specified model.""" self.model_id = model_id self.device = "cuda" if torch.cuda.is_available() else "cpu" # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained( model_id, # token=os.getenv("HF_TOKEN"), trust_remote_code=True ) # Set pad token if not set if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load model with appropriate settings self.model = AutoModelForCausalLM.from_pretrained( model_id, # token=os.getenv("HF_TOKEN"), torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, trust_remote_code=True, low_cpu_mem_usage=True ) if not torch.cuda.is_available(): self.model = self.model.to(self.device) self.model.eval() self.system_prompt = ( "You are a helpful, friendly AI assistant. " "You provide clear, accurate, and concise responses. " "You can help with various tasks including coding, analysis, and general questions." ) def _format_messages(self, message: str, history: list) -> list: """ Format the conversation history into the chat template format. Args: message: The current user message history: List of [user_msg, assistant_msg] pairs Returns: List of message dictionaries for the chat template """ messages = [{"role": "system", "content": self.system_prompt}] for user_msg, assistant_msg in history: messages.append({"role": "user", "content": user_msg}) if assistant_msg: messages.append({"role": "assistant", "content": assistant_msg}) messages.append({"role": "user", "content": message}) return messages def chat(self, message: str, history: list) -> str: """ Generate a response to the user's message using transformers. Args: message: The user's input message history: Conversation history as list of [user, assistant] pairs Returns: The assistant's response """ messages = self._format_messages(message, history) try: # Apply chat template prompt = self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Tokenize input inputs = self.tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=4096 ).to(self.device) # Generate response with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=1024, temperature=0.7, top_p=0.95, do_sample=True, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id ) # Decode response (only the new tokens) response = self.tokenizer.decode( outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True ) return response.strip() except Exception as e: return f"Error generating response: {str(e)}" def chat_stream(self, message: str, history: list) -> Generator[str, None, None]: """ Stream a response to the user's message for better UX. 
    def chat(self, message: str, history: list) -> str:
        """
        Generate a response to the user's message using transformers.

        Args:
            message: The user's input message
            history: Conversation history as list of [user, assistant] pairs

        Returns:
            The assistant's response
        """
        messages = self._format_messages(message, history)

        try:
            # Apply chat template
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096
            ).to(self.device)

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1024,
                    temperature=0.7,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )

            # Decode response (only the new tokens)
            response = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )

            return response.strip()

        except Exception as e:
            return f"Error generating response: {str(e)}"

    def chat_stream(self, message: str, history: list) -> Generator[str, None, None]:
        """
        Stream a response to the user's message for better UX.

        Args:
            message: The user's input message
            history: Conversation history

        Yields:
            Cumulative text of the response as it is generated
        """
        messages = self._format_messages(message, history)

        try:
            # Apply chat template
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            # Tokenize input
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=4096
            ).to(self.device)

            # Create streamer
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )

            # Generation kwargs
            generation_kwargs = dict(
                **inputs,
                max_new_tokens=1024,
                temperature=0.7,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                streamer=streamer
            )

            # Run generation in a separate thread so this one can consume
            # the streamer. Note: exceptions raised inside model.generate()
            # occur on the worker thread and are not caught by this
            # try/except.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            # Stream the response, yielding the accumulated text each time
            response = ""
            for new_text in streamer:
                response += new_text
                yield response.strip()

            thread.join()

        except Exception as e:
            yield f"Error generating response: {str(e)}"


# Default chatbot instance (lazy loading)
_chatbot = None


def get_chatbot() -> Chatbot:
    """Get or create the chatbot instance."""
    global _chatbot
    if _chatbot is None:
        _chatbot = Chatbot()
    return _chatbot


def chat_fn(message: str, history: list) -> str:
    """
    Function to be used with Gradio ChatInterface.

    Args:
        message: User's input message
        history: Conversation history

    Returns:
        Assistant's response
    """
    return get_chatbot().chat(message, history)


def chat_stream_fn(message: str, history: list) -> Generator[str, None, None]:
    """
    Streaming function for Gradio ChatInterface.

    Args:
        message: User's input message
        history: Conversation history

    Yields:
        Response chunks
    """
    yield from get_chatbot().chat_stream(message, history)
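
# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original
# module). Runs a quick local smoke test and, if Gradio is installed, serves
# the streaming function through gr.ChatInterface. Assumes the machine can
# download and hold gpt-oss-20b; the prompts below are arbitrary examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Non-streaming smoke test
    print(chat_fn("What is the capital of France?", history=[]))

    # Streaming smoke test: each yielded value is the cumulative response,
    # so only the final one needs to be printed.
    last = ""
    for partial in chat_stream_fn("Count from 1 to 5.", history=[]):
        last = partial
    print(last)

    # Optional Gradio UI. Note that _format_messages expects tuple-style
    # history ([user, assistant] pairs); newer Gradio versions default to a
    # message-dict history, so depending on the installed version the
    # interface may need type="tuples" or an adapted _format_messages.
    try:
        import gradio as gr

        demo = gr.ChatInterface(fn=chat_stream_fn, title="gpt-oss-20b Chatbot")
        demo.launch()
    except ImportError:
        pass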