from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Tuple
import os


# Use a very small model for testing
class UniversalChatModel:
    def __init__(self, model_name: str):
        self.model_name = model_name

        print(f"Loading tokenizer for {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=os.getenv("HF_TOKEN")
        )

        # Set padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token or "<|endoftext|>"

        print(f"Loading model {model_name}...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=os.getenv("HF_TOKEN")
        )
        print("Model loaded successfully!")

    def format_prompt_fallback(self, messages: List[dict]) -> str:
        """Universal ChatML format for models without chat templates"""
        chatml = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            chatml += f"<|im_start|>{role}\n{content}<|im_end|>\n"
        chatml += "<|im_start|>assistant\n"
        return chatml
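
    # Example: for [{"role": "user", "content": "Hi"}] the fallback above returns
    # "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n"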

    def build_messages(self, history: List[Tuple[str, str]], current_message: str, system_prompt: str | None = None) -> List[dict]:
        """Build universal message format"""
        messages = []

        # Add system prompt
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add current message
        messages.append({"role": "user", "content": current_message})

        return messages
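
    # Example: build_messages([("Hi", "Hello!")], "How are you?", "Be brief") returns
    # [{"role": "system", "content": "Be brief"},
    #  {"role": "user", "content": "Hi"},
    #  {"role": "assistant", "content": "Hello!"},
    #  {"role": "user", "content": "How are you?"}]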

    def format_prompt(self, messages: List[dict]) -> str:
        """Format prompt using model's chat template or fallback"""
        # Try the model's built-in chat template first
        if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
            try:
                prompt = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                return prompt
            except Exception as e:
                print(f"Warning: Failed to use chat template: {e}")

        # Fallback to ChatML
        return self.format_prompt_fallback(messages)

    def extract_response(self, prompt: str, generated_text: str) -> str:
        """Universal response extraction with comprehensive token cleanup"""
        # Token patterns to strip from the decoded text
        token_patterns = [
            # ChatML tokens
            "<|im_start|>", "<|im_end|>",
            # Common end tokens
            "</s>", "<eos>",
            # Llama/Mistral-style instruction markers
            "[/inst]", "[inst]", "<</sys>>", "<</inst>>",
            # Additional patterns
            "[/assistant]", "[assistant]", "[/user]", "[user]"
        ]

        def clean_response(response: str) -> str:
            """Strip all known token patterns from the response"""
            for pattern in token_patterns:
                response = response.replace(pattern, "")
            # Clean up extra whitespace and newlines
            response = response.strip()
            # Remove leading/trailing punctuation left over from tokens
            response = response.strip('.,!?;:\n\r\t ')
            return response

        def extract_chatml_response(text: str) -> str | None:
            """Extract response using ChatML markers"""
            if "<|im_start|>assistant\n" in text:
                parts = text.split("<|im_start|>assistant\n")
                if len(parts) > 1:
                    response = parts[-1]
                    if "<|im_end|>" in response:
                        response = response.split("<|im_end|>")[0]
                    return clean_response(response)
            return None

        def extract_inst_response(text: str) -> str | None:
            """Extract response using Llama/Mistral-style [INST] markers"""
            inst_patterns = ["[inst]", "[/inst]"]
            text_lower = text.lower()
            for pattern in inst_patterns:
                idx = text_lower.rfind(pattern)
                if idx != -1:
                    # Slice the original text so the response keeps its casing
                    response = text[idx + len(pattern):]
                    return clean_response(response)
            return None

        def extract_after_prompt(text: str, prompt: str) -> str | None:
            """Extract the response that comes directly after the prompt"""
            if text.startswith(prompt):
                response = text[len(prompt):]
                return clean_response(response)
            return None

        def extract_last_assistant_message(text: str) -> str:
            """Fallback: find the last assistant-like message"""
            # Look for various assistant indicators
            assistant_indicators = ["assistant", "[inst]", "[/inst]"]
            text_lower = text.lower()
            best_response = ""
            for indicator in assistant_indicators:
                idx = text_lower.rfind(indicator)
                if idx != -1:
                    # Slice the original text so the response keeps its casing
                    candidate = text[idx + len(indicator):]
                    if len(candidate) > len(best_response):
                        best_response = candidate
            if best_response:
                return clean_response(best_response)
            # Final fallback: just clean the whole thing
            return clean_response(text)
| print(f"\n\n----- EXTRACTION DEBUG -----\n") | |
| print(f"Generated text length: {len(generated_text)}") | |
| print(f"Prompt length: {len(prompt)}") | |
| print(f"Generated starts with prompt: {generated_text.startswith(prompt)}") | |
| # Try extraction methods in order of reliability | |
| response = None | |
| # Method 1: ChatML extraction | |
| response = extract_chatml_response(generated_text) | |
| if response: | |
| print("Used ChatML extraction") | |
| else: | |
| # Method 2: inst pattern extraction (from your example) | |
| response = extract_inst_response(generated_text) | |
| if response: | |
| print("Used inst pattern extraction") | |
| else: | |
| # Method 3: Extract after prompt | |
| response = extract_after_prompt(generated_text, prompt) | |
| if response: | |
| print("Used extract-after-prompt method") | |
| else: | |
| # Method 4: Fallback to last assistant message | |
| response = extract_last_assistant_message(generated_text) | |
| print("Used fallback extraction") | |
| print(f"Extracted response: '{response}'") | |
| print(f"Response length: {len(response)}") | |
| print(f"----- END EXTRACTION DEBUG -----\n\n") | |
| return response or "" | |
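
    # Note: tokenizer.decode(..., skip_special_tokens=True) already drops tokens the
    # tokenizer registers as "special", so the string-level cleanup above mainly
    # catches markers that a given model emits as plain text.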

    def generate(self, message: str, history: List[Tuple[str, str]] | None = None, system_prompt: str | None = None) -> str:
        """Generate response using universal chat template system"""
        if history is None:
            history = []

        # Build messages
        messages = self.build_messages(history, message, system_prompt)

        # Format prompt
        prompt = self.format_prompt(messages)
        print(f"\n\n----- PROMPT -----\n{prompt}\n-----------------\n\n")

        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)

        # Move to model device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        # Generate
        generation_config = {
            "max_new_tokens": 150,
            "do_sample": True,
            "temperature": 0.7,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **generation_config
            )

        # Decode
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\n\n----- GENERATED -----\n{generated_text}\n-------------------\n\n")

        # Extract response
        response = self.extract_response(prompt, generated_text)

        return response


# Initialize with a tiny model for testing
# MODEL_NAME = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
SYSTEM_PROMPT = "You are a helpful assistant."

# Create global model instance
print("Creating model instance...")
chat_model = UniversalChatModel(MODEL_NAME)


def generate(message: str, history: List[Tuple[str, str]]) -> str:
    """Generate response using universal chat model"""
    return chat_model.generate(message, history, SYSTEM_PROMPT)


if __name__ == "__main__":
    # Quick test
    print("Testing generation...")
    reply = generate("What is 2+2?", [])
    print(f"Final response: {reply}")