from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Tuple
import os

# Use a very small model for testing


class UniversalChatModel:
    def __init__(self, model_name: str):
        self.model_name = model_name

        print(f"Loading tokenizer for {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=os.getenv("HF_TOKEN")
        )

        # Set padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token or "<|endoftext|>"

        print(f"Loading model {model_name}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=os.getenv("HF_TOKEN")
        )
        print("Model loaded successfully!")

    def format_prompt_fallback(self, messages: List[dict]) -> str:
        """Universal ChatML format for models without chat templates"""
        chatml = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            chatml += f"<|im_start|>{role}\n{content}<|im_end|>\n"
        chatml += "<|im_start|>assistant\n"
        return chatml

    def build_messages(self, history: List[Tuple[str, str]], current_message: str,
                       system_prompt: str = None) -> List[dict]:
        """Build universal message format"""
        messages = []

        # Add system prompt
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add current message
        messages.append({"role": "user", "content": current_message})

        return messages

    def format_prompt(self, messages: List[dict]) -> str:
        """Format prompt using the model's chat template or the ChatML fallback"""
        # Try the model's built-in chat template first
        if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
            try:
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                return prompt
            except Exception as e:
                print(f"Warning: Failed to use chat template: {e}")

        # Fall back to ChatML
        return self.format_prompt_fallback(messages)

    def extract_response(self, prompt: str, generated_text: str) -> str:
        """Enhanced universal response extraction with comprehensive token cleanup"""
        # Define token patterns to clean
        token_patterns = [
            # ChatML tokens
            "<|im_start|>", "<|im_end|>",
            # Common begin/end-of-sequence tokens
            "<s>", "</s>",
            # Llama-2-style instruction tokens
            "[/INST]", "[INST]", "[/inst]", "[inst]", "<<SYS>>", "<</SYS>>",
            # Additional patterns
            "[/assistant]", "[assistant]", "[/user]", "[user]"
        ]

        def clean_response(response: str) -> str:
            """Clean all known token patterns from the response"""
            for pattern in token_patterns:
                response = response.replace(pattern, "")
            # Clean up extra whitespace and newlines
            response = response.strip()
            # Remove leading/trailing punctuation that might be left over from tokens
            response = response.strip('.,!?;:\n\r\t ')
            return response

        def extract_chatml_response(text: str) -> str | None:
            """Extract response using ChatML markers"""
            if "<|im_start|>assistant\n" in text:
                parts = text.split("<|im_start|>assistant\n")
                if len(parts) > 1:
                    response = parts[-1]
                    if "<|im_end|>" in response:
                        response = response.split("<|im_end|>")[0]
                    return clean_response(response)
            return None

        def extract_inst_response(text: str) -> str | None:
            """Extract response using [INST]/[/INST] markers"""
            # Look for [inst] or [/inst] patterns, case-insensitively
            inst_patterns = ["[inst]", "[/inst]"]
            for pattern in inst_patterns:
                idx = text.lower().rfind(pattern)
                if idx != -1:
                    # Slice the original text so the response keeps its casing
                    response = text[idx + len(pattern):]
                    return clean_response(response)
            return None
        def extract_after_prompt(text: str, prompt: str) -> str | None:
            """Extract response that comes directly after the prompt"""
            if text.startswith(prompt):
                response = text[len(prompt):]
                return clean_response(response)
            return None

        def extract_last_assistant_message(text: str) -> str:
            """Fallback: find the last assistant-like message"""
            # Look for various assistant indicators, case-insensitively
            assistant_indicators = ["assistant", "[inst]", "[/inst]"]
            text_lower = text.lower()

            best_response = ""
            for indicator in assistant_indicators:
                idx = text_lower.rfind(indicator)
                if idx != -1:
                    # Slice the original text so the candidate keeps its casing
                    candidate = text[idx + len(indicator):]
                    if len(candidate) > len(best_response):
                        best_response = candidate

            if best_response:
                return clean_response(best_response)

            # Final fallback: just clean the whole thing
            return clean_response(text)

        print("\n\n----- EXTRACTION DEBUG -----\n")
        print(f"Generated text length: {len(generated_text)}")
        print(f"Prompt length: {len(prompt)}")
        print(f"Generated starts with prompt: {generated_text.startswith(prompt)}")

        # Try extraction methods in order of reliability
        response = None

        # Method 1: ChatML extraction
        response = extract_chatml_response(generated_text)
        if response:
            print("Used ChatML extraction")
        else:
            # Method 2: [INST] pattern extraction
            response = extract_inst_response(generated_text)
            if response:
                print("Used inst pattern extraction")
            else:
                # Method 3: Extract after prompt
                response = extract_after_prompt(generated_text, prompt)
                if response:
                    print("Used extract-after-prompt method")
                else:
                    # Method 4: Fallback to last assistant message
                    response = extract_last_assistant_message(generated_text)
                    print("Used fallback extraction")

        print(f"Extracted response: '{response}'")
        print(f"Response length: {len(response)}")
        print("----- END EXTRACTION DEBUG -----\n\n")

        return response or ""

    def generate(self, message: str, history: List[Tuple[str, str]] | None = None,
                 system_prompt: str | None = None) -> str:
        """Generate response using the universal chat template system"""
        if history is None:
            history = []

        # Build messages
        messages = self.build_messages(history, message, system_prompt)

        # Format prompt
        prompt = self.format_prompt(messages)
        print(f"\n\n----- PROMPT -----\n{prompt}\n-----------------\n\n")

        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

        # Move to model device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Generate
        generation_config = {
            "max_new_tokens": 150,
            "do_sample": True,
            "temperature": 0.7,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **generation_config
            )

        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\n\n----- GENERATED -----\n{generated_text}\n-------------------\n\n")

        # Extract response
        response = self.extract_response(prompt, generated_text)

        return response


# Initialize with a tiny model for testing:
# MODEL_NAME = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SYSTEM_PROMPT = "You are a helpful assistant."
# Create global model instance
print("Creating model instance...")
chat_model = UniversalChatModel(MODEL_NAME)


def generate(message: str, history: List[Tuple[str, str]]) -> str:
    """Generate response using the universal chat model"""
    return chat_model.generate(message, history, SYSTEM_PROMPT)


if __name__ == "__main__":
    # Quick test
    print("Testing generation...")
    reply = generate("What is 2+2?", [])
    print(f"Final response: {reply}")
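
    # Illustrative sketch (not part of the original quick test): a second,
    # multi-turn call showing the expected history format consumed by
    # build_messages() -- a list of (user_message, assistant_reply) tuples.
    print("Testing multi-turn generation...")
    history = [("What is 2+2?", reply)]
    follow_up = generate("Now multiply that by 10.", history)
    print(f"Follow-up response: {follow_up}")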