Spaces:
Runtime error
Runtime error
| """ | |
| Chatbot module using HuggingFace Transformers. | |
| Uses gpt-oss-20b model with AutoModelForCausalLM, AutoTokenizer, and chat templates. | |
| """ | |
| import os | |
| import torch | |
| from typing import Generator | |
| from dotenv import load_dotenv | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer | |
| from threading import Thread | |
# Populate os.environ from a local .env file when one is present.
load_dotenv()

# Hugging Face Hub checkpoint used when Chatbot() is constructed with no model_id.
MODEL_ID = "openai/gpt-oss-20b"
class Chatbot:
    """
    A chatbot backed by a Hugging Face causal language model.

    The model and tokenizer are loaded with ``AutoModelForCausalLM`` and
    ``AutoTokenizer``; prompts are rendered with the tokenizer's chat template.
    """

    def __init__(
        self,
        model_id: str = MODEL_ID,
        *,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_input_tokens: int = 4096,
    ):
        """
        Load the tokenizer and model for ``model_id``.

        Args:
            model_id: Hugging Face Hub id (or local path) of the checkpoint.
            max_new_tokens: Upper bound on tokens generated per reply.
            temperature: Sampling temperature passed to ``generate``.
            top_p: Nucleus-sampling probability mass passed to ``generate``.
            max_input_tokens: Positive prompt-length budget in tokens. The
                oldest history turns are dropped first, and as a last resort
                the prompt is cut from the left, so it never exceeds this.
        """
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.max_input_tokens = max_input_tokens

        # No explicit token argument: huggingface_hub reads the HF_TOKEN
        # environment variable on its own when the checkpoint needs auth.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        # generate() needs a pad id; fall back to EOS when none is defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        use_cuda = torch.cuda.is_available()
        # NOTE(review): the CPU path loads weights in float32 (4 bytes per
        # parameter). For a ~20B-parameter checkpoint that is roughly 80 GB of
        # RAM, which is more than most hosts have. Confirm the deployment
        # hardware before relying on this path.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        if not use_cuda:
            self.model = self.model.to(self.device)
        self.model.eval()

        self.system_prompt = (
            "You are a helpful, friendly AI assistant. "
            "You provide clear, accurate, and concise responses. "
            "You can help with various tasks including coding, analysis, and general questions."
        )

    def _format_messages(self, message: str, history: list) -> list:
        """
        Build the chat-template message list for one turn.

        Args:
            message: The current user message.
            history: Prior turns, either as ``[user_msg, assistant_msg]``
                pairs or as ``{"role": ..., "content": ...}`` dicts. These are
                the "tuples" and "messages" history formats of Gradio's
                ChatInterface.
        Returns:
            ``[system, *history, user]`` as a list of role/content dicts.
        """
        messages = [{"role": "system", "content": self.system_prompt}]
        for turn in history or []:
            if isinstance(turn, dict):
                # "messages"-style history. Unpacking such a dict as a pair
                # would yield its keys ("role", "content"), not its values.
                # Only plain-text content is forwarded to the template.
                content = turn.get("content")
                if isinstance(content, str):
                    messages.append({"role": turn["role"], "content": content})
                continue
            user_msg, assistant_msg = turn
            messages.append({"role": "user", "content": user_msg})
            # The newest pair may not have an assistant reply yet.
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        return messages

    def _encode(self, messages: list) -> torch.Tensor:
        """Render ``messages`` with the chat template; return ``input_ids`` of shape (1, seq_len)."""
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The rendered template already contains the model's special tokens,
        # so don't let the tokenizer add another BOS on top.
        return self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )["input_ids"]

    def _prepare_inputs(self, messages: list) -> dict:
        """
        Tokenize ``messages`` so the prompt fits in ``max_input_tokens``.

        Right-side truncation would cut off the newest user message and the
        assistant generation prompt. Instead, the oldest turns between the
        system prompt and the new message are dropped first. If a single
        message is still too long, only its most recent tokens are kept.

        Returns:
            ``input_ids`` and ``attention_mask`` on the model's device,
            ready to splat into ``model.generate``.
        """
        messages = list(messages)
        input_ids = self._encode(messages)
        while input_ids.shape[1] > self.max_input_tokens and len(messages) > 2:
            del messages[1]
            # Don't leave an orphaned assistant reply right after the system prompt.
            while len(messages) > 2 and messages[1]["role"] == "assistant":
                del messages[1]
            input_ids = self._encode(messages)
        if input_ids.shape[1] > self.max_input_tokens:
            input_ids = input_ids[:, -self.max_input_tokens:]
        input_ids = input_ids.to(self.model.device)
        # Single unpadded sequence: every position is attended to.
        return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}

    def _generation_kwargs(self) -> dict:
        """Sampling settings shared by ``chat`` and ``chat_stream``."""
        return dict(
            max_new_tokens=self.max_new_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

    def chat(self, message: str, history: list) -> str:
        """
        Generate a complete response to the user's message.

        Args:
            message: The user's input message.
            history: Conversation history (see ``_format_messages``).
        Returns:
            The assistant's response, or an ``"Error generating response: ..."``
            string if tokenization or generation fails.
        """
        messages = self._format_messages(message, history)
        try:
            inputs = self._prepare_inputs(messages)
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **self._generation_kwargs())
            # Decode only the newly generated tokens, not the echoed prompt.
            # NOTE(review): skip_special_tokens=True also strips any channel
            # markers a reasoning-model template emits; confirm whether
            # reasoning text ends up in the reply for the chosen model.
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_len:],
                skip_special_tokens=True,
            )
            return response.strip()
        except Exception as e:
            # UI boundary: report the failure in the chat instead of crashing.
            return f"Error generating response: {str(e)}"

    def chat_stream(self, message: str, history: list) -> Generator[str, None, None]:
        """
        Stream a response to the user's message for better UX.

        Args:
            message: The user's input message.
            history: Conversation history (see ``_format_messages``).
        Yields:
            The response accumulated so far (stripped), once per new text chunk.
            On failure, an ``"Error generating response: ..."`` string.
        """
        messages = self._format_messages(message, history)
        try:
            inputs = self._prepare_inputs(messages)
            streamer = TextIteratorStreamer(
                self.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True,
            )
            errors: list = []

            def _generate() -> None:
                # Runs in the worker thread. If generate() raises, the streamer
                # never gets its end signal and iterating it below would block
                # forever, so record the error and end the stream explicitly.
                try:
                    self.model.generate(
                        **inputs,
                        **self._generation_kwargs(),
                        streamer=streamer,
                    )
                except Exception as exc:
                    errors.append(exc)
                    streamer.end()

            thread = Thread(target=_generate)
            thread.start()
            response = ""
            for new_text in streamer:
                response += new_text
                yield response.strip()
            thread.join()
            if errors:
                yield f"Error generating response: {str(errors[0])}"
        except Exception as e:
            yield f"Error generating response: {str(e)}"
# Shared Chatbot instance; stays None until get_chatbot() first builds it,
# so the model is loaded lazily on first use.
_chatbot = None


def get_chatbot() -> Chatbot:
    """Return the shared Chatbot, constructing it on the first call."""
    global _chatbot
    if _chatbot is not None:
        return _chatbot
    _chatbot = Chatbot()
    return _chatbot
def chat_fn(message: str, history: list) -> str:
    """
    Non-streaming callback for a Gradio ChatInterface.

    Args:
        message: Text the user just submitted.
        history: Earlier conversation turns, forwarded unchanged.
    Returns:
        The assistant's full reply text.
    """
    bot = get_chatbot()
    reply = bot.chat(message, history)
    return reply
def chat_stream_fn(message: str, history: list) -> Generator[str, None, None]:
    """
    Streaming callback for a Gradio ChatInterface.

    Args:
        message: Text the user just submitted.
        history: Earlier conversation turns, forwarded unchanged.
    Yields:
        Whatever the shared Chatbot's chat_stream yields, passed straight through.
    """
    bot = get_chatbot()
    yield from bot.chat_stream(message, history)