# ai-hello-world / ai_engine.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Tuple
import os
import re  # used for case-insensitive token cleanup in extract_response


class UniversalChatModel:
    """Model-agnostic chat wrapper around a Hugging Face causal LM."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        print(f"Loading tokenizer for {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=os.getenv("HF_TOKEN")
        )
        # Set padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token or "<|endoftext|>"
        print(f"Loading model {model_name}...")
        # device_map="auto" needs the accelerate package; float16 assumes a GPU
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=os.getenv("HF_TOKEN")
        )
        print("Model loaded successfully!")

    def format_prompt_fallback(self, messages: List[dict]) -> str:
        """Universal ChatML format for models without chat templates"""
        chatml = ""
        for message in messages:
            role = message["role"]
            content = message["content"]
            chatml += f"<|im_start|>{role}\n{content}<|im_end|>\n"
        chatml += "<|im_start|>assistant\n"
        return chatml

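    # For example, messages = [{"role": "user", "content": "Hi"}] renders as:
    #   <|im_start|>user
    #   Hi<|im_end|>
    #   <|im_start|>assistant
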
    def build_messages(self, history: List[Tuple[str, str]], current_message: str, system_prompt: str | None = None) -> List[dict]:
        """Build the universal message format"""
        messages = []
        # Add system prompt
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Add history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        # Add current message
        messages.append({"role": "user", "content": current_message})
        return messages

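    # For example, build_messages([("Hi", "Hello!")], "How are you?", "Be brief")
    # returns:
    #   [{"role": "system", "content": "Be brief"},
    #    {"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello!"},
    #    {"role": "user", "content": "How are you?"}]
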
    def format_prompt(self, messages: List[dict]) -> str:
        """Format the prompt using the model's chat template, or fall back to ChatML"""
        # Try the model's built-in chat template first
        if hasattr(self.tokenizer, 'chat_template') and self.tokenizer.chat_template:
            try:
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                print(f"Warning: Failed to use chat template: {e}")
        # Fallback to ChatML
        return self.format_prompt_fallback(messages)

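    # With meta-llama/Llama-2-7b-chat-hf, the bundled template renders roughly as
    #   <s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]
    # which is why extract_response below also handles [INST]-style markers.
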
    def extract_response(self, prompt: str, generated_text: str) -> str:
        """Universal response extraction with comprehensive token cleanup"""
        # Token patterns to strip from extracted responses (matched
        # case-insensitively, so [INST] and [inst] are both caught)
        token_patterns = [
            # ChatML tokens
            "<|im_start|>", "<|im_end|>",
            # Common end-of-sequence tokens
            "</s>", "<eos>",
            # Llama-2 style instruction tokens
            "[INST]", "[/INST]", "<<SYS>>", "<</SYS>>",
            # Additional role markers some models emit
            "[/assistant]", "[assistant]", "[/user]", "[user]"
        ]

        def clean_response(response: str) -> str:
            """Strip all known token patterns from the response"""
            for pattern in token_patterns:
                response = re.sub(re.escape(pattern), "", response, flags=re.IGNORECASE)
            # Clean up extra whitespace and stray punctuation left by tokens
            response = response.strip()
            response = response.strip('.,!?;:\n\r\t ')
            return response

        def extract_chatml_response(text: str) -> str | None:
            """Extract the response using ChatML markers"""
            if "<|im_start|>assistant\n" in text:
                parts = text.split("<|im_start|>assistant\n")
                if len(parts) > 1:
                    response = parts[-1]
                    if "<|im_end|>" in response:
                        response = response.split("<|im_end|>")[0]
                    return clean_response(response)
            return None

        def extract_inst_response(text: str) -> str | None:
            """Extract the response after Llama-2 style [INST] markers"""
            # Prefer [/inst]: the reply follows the closing marker. Search
            # case-insensitively, but slice the original text so the
            # response keeps its casing.
            for pattern in ["[/inst]", "[inst]"]:
                idx = text.lower().rfind(pattern)
                if idx != -1:
                    return clean_response(text[idx + len(pattern):])
            return None

        def extract_after_prompt(text: str, prompt: str) -> str | None:
            """Extract the response that comes directly after the prompt"""
            if text.startswith(prompt):
                return clean_response(text[len(prompt):])
            return None

        def extract_last_assistant_message(text: str) -> str:
            """Fallback: take the longest tail after any assistant-like marker"""
            text_lower = text.lower()
            best_response = ""
            for indicator in ["assistant", "[inst]", "[/inst]"]:
                idx = text_lower.rfind(indicator)
                if idx != -1:
                    # Slice the original text so casing is preserved
                    candidate = text[idx + len(indicator):]
                    if len(candidate) > len(best_response):
                        best_response = candidate
            if best_response:
                return clean_response(best_response)
            # Final fallback: just clean the whole thing
            return clean_response(text)

print(f"\n\n----- EXTRACTION DEBUG -----\n")
print(f"Generated text length: {len(generated_text)}")
print(f"Prompt length: {len(prompt)}")
print(f"Generated starts with prompt: {generated_text.startswith(prompt)}")
# Try extraction methods in order of reliability
response = None
# Method 1: ChatML extraction
response = extract_chatml_response(generated_text)
if response:
print("Used ChatML extraction")
else:
# Method 2: inst pattern extraction (from your example)
response = extract_inst_response(generated_text)
if response:
print("Used inst pattern extraction")
else:
# Method 3: Extract after prompt
response = extract_after_prompt(generated_text, prompt)
if response:
print("Used extract-after-prompt method")
else:
# Method 4: Fallback to last assistant message
response = extract_last_assistant_message(generated_text)
print("Used fallback extraction")
print(f"Extracted response: '{response}'")
print(f"Response length: {len(response)}")
print(f"----- END EXTRACTION DEBUG -----\n\n")
return response or ""
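    # For example, with a ChatML prompt, generated text ending in
    #   "<|im_start|>assistant\n4<|im_end|>"
    # is reduced to "4" by method 1.
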
    def generate(self, message: str, history: List[Tuple[str, str]] | None = None, system_prompt: str | None = None) -> str:
        """Generate a response using the universal chat template system"""
        if history is None:
            history = []
        # Build messages
        messages = self.build_messages(history, message, system_prompt)
        # Format prompt
        prompt = self.format_prompt(messages)
        print(f"\n\n----- PROMPT -----\n{prompt}\n-----------------\n\n")
        # Tokenize
        inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
        # Move inputs to the model's device
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        # Generation settings
        generation_config = {
            "max_new_tokens": 150,
            "do_sample": True,
            "temperature": 0.7,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
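        # Note: sampling (do_sample=True) makes output non-deterministic;
        # set do_sample=False for reproducible greedy decoding.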
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **generation_config
            )
        # Decode; skip_special_tokens=True strips tokens such as <s>, so the
        # decoded text may not start with the literal prompt string
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"\n\n----- GENERATED -----\n{generated_text}\n-------------------\n\n")
        # Extract the assistant's reply from the full decoded text
        response = self.extract_response(prompt, generated_text)
        return response


# Default model; swap in the tiny random model below for quick smoke tests
# MODEL_NAME = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"  # gated repo: HF_TOKEN must have access
SYSTEM_PROMPT = "You are a helpful assistant."

# Create the global model instance
print("Creating model instance...")
chat_model = UniversalChatModel(MODEL_NAME)


def generate(message: str, history: List[Tuple[str, str]]) -> str:
    """Generate a response using the global universal chat model"""
    return chat_model.generate(message, history, SYSTEM_PROMPT)


if __name__ == "__main__":
    # Quick test
    print("Testing generation...")
    reply = generate("What is 2+2?", [])
    print(f"Final response: {reply}")