# ai_resume/llama_chat_interface.py
import os
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig
)
from peft import PeftModel
import warnings
from datetime import datetime
import json
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # Avoid fork-related tokenizer warnings
class LlamaChat:
def __init__(self, model_path, system_message=None, use_quantization=True, max_memory_gb=8):
"""
Initialize the chat interface with the fine-tuned Llama model
Args:
model_path: Path to the fine-tuned model directory
system_message: System message to use for conversations (persona/context)
use_quantization: Whether to use 4-bit quantization (recommended for 8GB GPU)
max_memory_gb: Maximum GPU memory to use
"""
self.model_path = model_path
self.use_quantization = use_quantization
self.max_memory_gb = max_memory_gb
# Default system message if none provided
self.system_message = system_message or (
"You are Alexander Molchevskyi β€” a senior software engineer with over 20 years "
"of professional experience across embedded, desktop, and server systems. "
"Skilled in C++, Rust, Python, AI infrastructure, compilers, WebAssembly, and "
"developer tooling. You answer interview questions clearly, professionally, and naturally."
)
print("πŸš€ Loading Llama Chat Interface...")
print(f"Model path: {model_path}")
print(f"System message: {self.system_message[:100]}{'...' if len(self.system_message) > 100 else ''}")
# Check CUDA availability
if torch.cuda.is_available():
print(f"βœ… CUDA available: {torch.cuda.get_device_name()}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
else:
print("⚠️ CUDA not available, using CPU (will be slow)")
self.tokenizer = None
self.model = None
self.conversation_history = []
self._load_model()
def _setup_quantization_config(self):
"""Setup 4-bit quantization config for memory efficiency"""
if not self.use_quantization:
return None
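        # Rough sizing (back-of-envelope, assuming a ~3B-parameter model):
        # bf16 weights take 2 bytes/param (~6 GB), while NF4 packs roughly
        # 0.5 bytes/param (~1.5 GB) plus quantization constants and KV cache,
        # which is why 4-bit quantization is the default on an 8 GB GPU.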
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
def _load_model(self):
"""Load the tokenizer and model"""
try:
print("πŸ“š Loading tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
trust_remote_code=True,
padding_side="left" # For generation
)
# Add pad token if it doesn't exist
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
print("🧠 Loading base model...")
            # Setup quantization if requested
            quantization_config = self._setup_quantization_config()
            # Cap GPU memory so accelerate's device_map honors max_memory_gb
            max_memory = (
                {0: f"{self.max_memory_gb}GiB"} if torch.cuda.is_available() else None
            )
# Check if this is a PEFT model (has adapter_config.json)
adapter_config_path = os.path.join(self.model_path, "adapter_config.json")
is_peft_model = os.path.exists(adapter_config_path)
if is_peft_model:
print("πŸ”§ Detected PEFT (LoRA) model, loading base model first...")
                # Load adapter config to get base model name
                with open(adapter_config_path, 'r') as f:
                    adapter_config = json.load(f)
                base_model_name = adapter_config.get('base_model_name_or_path')
                if not base_model_name:
                    raise ValueError(
                        "adapter_config.json has no 'base_model_name_or_path'; "
                        "cannot locate the base model for this LoRA adapter"
                    )
                print(f"Base model: {base_model_name}")
# Load base model
                base_model = AutoModelForCausalLM.from_pretrained(
                    base_model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    max_memory=max_memory,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    use_cache=True,  # Enable cache for inference
                )
# Load PEFT model (LoRA adapter)
print("🎯 Loading LoRA adapter...")
self.model = PeftModel.from_pretrained(base_model, self.model_path)
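                # Optional (sketch): merging the adapter into the base weights
                # via self.model = self.model.merge_and_unload() removes the
                # LoRA indirection for slightly faster inference, but merging
                # is generally not supported on 4-bit quantized weights, so it
                # is left disabled here.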
else:
# Regular fine-tuned model (not PEFT)
print("πŸ“¦ Loading fine-tuned model...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    quantization_config=quantization_config,
                    device_map="auto",
                    max_memory=max_memory,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                    use_cache=True,  # Enable cache for inference
                )
# Set model to evaluation mode
self.model.eval()
print("βœ… Model loaded successfully!")
# Print model info
if hasattr(self.model, 'print_trainable_parameters'):
self.model.print_trainable_parameters()
except Exception as e:
print(f"❌ Error loading model: {str(e)}")
raise
def _format_message(self, user_message):
"""Format user message with system context using Llama's chat template"""
return f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{self.system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
def generate_response(self, user_message, max_new_tokens=200, temperature=0.7,
top_p=0.9, repetition_penalty=1.1, do_sample=True):
"""
Generate a response to the user message
Args:
user_message: The user's input message
max_new_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (higher = more random)
top_p: Nucleus sampling parameter
repetition_penalty: Penalty for repeating tokens
do_sample: Whether to use sampling or greedy decoding
"""
try:
# Format the input
formatted_input = self._format_message(user_message)
            # Tokenize input; the template already includes <|begin_of_text|>,
            # so keep the tokenizer from prepending a duplicate BOS token
            inputs = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                add_special_tokens=False,
                truncation=True,
                max_length=1024  # Matches the max_length used during training
            ).to(self.model.device)
# Generate response
print("πŸ€” Thinking...")
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    # Stop on either the generic EOS or Llama 3's end-of-turn
                    # marker, whichever the fine-tuned model emits
                    eos_token_id=[
                        self.tokenizer.eos_token_id,
                        self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
                    ],
num_return_sequences=1,
)
            # Decode only the newly generated tokens: `outputs` echoes the
            # prompt, and skip_special_tokens=True strips the header markers
            # that a split-based extraction would need, so slicing off the
            # prompt is the reliable way to isolate the assistant's reply
            prompt_length = inputs["input_ids"].shape[1]
            assistant_response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
            return assistant_response
except Exception as e:
return f"❌ Error generating response: {str(e)}"
def chat_loop(self):
"""Main chat loop"""
print("\n" + "="*60)
print("πŸ¦™ LLAMA FINE-TUNED CHAT INTERFACE")
print("="*60)
print("Commands:")
print(" β€’ Type your message and press Enter")
print(" β€’ '/help' - Show this help")
print(" β€’ '/system' - View or change system message")
print(" β€’ '/settings' - Adjust generation settings")
print(" β€’ '/history' - Show conversation history")
print(" β€’ '/clear' - Clear conversation history")
print(" β€’ '/save' - Save conversation to file")
print(" β€’ '/quit' or '/exit' - Exit the chat")
print("="*60)
# Default generation settings
settings = {
'max_new_tokens': 200,
'temperature': 0.7,
'top_p': 0.9,
'repetition_penalty': 1.1,
'do_sample': True
}
while True:
try:
# Get user input
user_input = input("\nπŸ‘€ You: ").strip()
if not user_input:
continue
# Handle commands
if user_input.lower() in ['/quit', '/exit']:
print("πŸ‘‹ Goodbye!")
break
elif user_input.lower() == '/help':
self._show_help()
continue
elif user_input.lower() == '/system':
self._manage_system_message()
continue
elif user_input.lower() == '/settings':
settings = self._adjust_settings(settings)
continue
elif user_input.lower() == '/history':
self._show_history()
continue
elif user_input.lower() == '/clear':
self.conversation_history.clear()
print("🧹 Conversation history cleared!")
continue
elif user_input.lower() == '/save':
self._save_conversation()
continue
# Generate response
response = self.generate_response(user_input, **settings)
# Display response
print(f"\nπŸ¦™ Alexander: {response}")
# Save to history
self.conversation_history.append({
'timestamp': datetime.now().isoformat(),
'system': self.system_message,
'user': user_input,
'assistant': response
})
except KeyboardInterrupt:
print("\n\nπŸ‘‹ Chat interrupted. Goodbye!")
break
except Exception as e:
print(f"\n❌ Error: {str(e)}")
def _manage_system_message(self):
"""Allow user to view or change the system message"""
print("\nπŸ€– SYSTEM MESSAGE MANAGEMENT:")
print("Current system message:")
print("-" * 60)
print(self.system_message)
print("-" * 60)
choice = input("\nOptions: [v]iew, [c]hange, or [Enter] to go back: ").strip().lower()
if choice == 'c' or choice == 'change':
print("\nEnter new system message (or press Enter to keep current):")
new_system = input("> ").strip()
if new_system:
self.system_message = new_system
print("βœ… System message updated!")
print("Note: This will affect all future conversations.")
else:
print("System message unchanged.")
elif choice == 'v' or choice == 'view':
# Already displayed above
pass
def _show_help(self):
"""Show help information"""
print("\nπŸ“‹ HELP:")
print("This is a chat interface for your fine-tuned Llama model.")
print("The model has been trained with system messages to embody Alexander Molchevskyi's")
print("professional persona and expertise in software engineering.")
print("\nTips:")
print("β€’ Ask technical questions about software engineering, AI, or development")
print("β€’ The model maintains context of being Alexander throughout conversations")
print("β€’ Use /system to view or modify the professional persona")
print("β€’ Use /settings to adjust creativity (temperature) and response length")
print("β€’ Higher temperature = more creative but less consistent")
print("β€’ Lower temperature = more focused and consistent")
def _adjust_settings(self, current_settings):
"""Allow user to adjust generation settings"""
print("\nβš™οΈ GENERATION SETTINGS:")
print("Current settings:")
for key, value in current_settings.items():
print(f" {key}: {value}")
new_settings = current_settings.copy()
try:
# Max tokens
max_tokens = input(f"\nMax response length ({current_settings['max_new_tokens']}): ").strip()
if max_tokens:
new_settings['max_new_tokens'] = max(1, min(500, int(max_tokens)))
# Temperature
temp = input(f"Temperature 0.1-2.0 ({current_settings['temperature']}): ").strip()
if temp:
new_settings['temperature'] = max(0.1, min(2.0, float(temp)))
# Top-p
top_p = input(f"Top-p 0.1-1.0 ({current_settings['top_p']}): ").strip()
if top_p:
new_settings['top_p'] = max(0.1, min(1.0, float(top_p)))
# Repetition penalty
rep_penalty = input(f"Repetition penalty 1.0-2.0 ({current_settings['repetition_penalty']}): ").strip()
if rep_penalty:
new_settings['repetition_penalty'] = max(1.0, min(2.0, float(rep_penalty)))
print("βœ… Settings updated!")
return new_settings
except ValueError:
print("❌ Invalid input. Settings unchanged.")
return current_settings
def _show_history(self):
"""Show conversation history"""
if not self.conversation_history:
print("πŸ“ No conversation history yet.")
return
print(f"\nπŸ“œ CONVERSATION HISTORY ({len(self.conversation_history)} exchanges):")
print("-" * 50)
        for exchange in self.conversation_history[-5:]:  # Show the last 5 exchanges
timestamp = exchange['timestamp'].split('T')[1].split('.')[0] # Just time
print(f"\n[{timestamp}]")
print(f"πŸ‘€ You: {exchange['user']}")
print(f"πŸ¦™ Alexander: {exchange['assistant'][:100]}{'...' if len(exchange['assistant']) > 100 else ''}")
if len(self.conversation_history) > 5:
print(f"\n... and {len(self.conversation_history) - 5} more exchanges")
def _save_conversation(self):
"""Save conversation to a JSON file"""
if not self.conversation_history:
print("πŸ“ No conversation to save.")
return
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"llama_chat_{timestamp}.json"
try:
with open(filename, 'w', encoding='utf-8') as f:
json.dump(self.conversation_history, f, indent=2, ensure_ascii=False)
print(f"πŸ’Ύ Conversation saved to: {filename}")
except Exception as e:
print(f"❌ Error saving conversation: {str(e)}")
def main():
"""Main function to start the chat interface"""
# Configuration
MODEL_PATH = "llama-3.2-3b-finetuned" # Path to your fine-tuned model
# Default system message (can be customized)
DEFAULT_SYSTEM_MESSAGE = (
"You are Alexander Molchevskyi β€” a senior software engineer with over 20 years "
"of professional experience across embedded, desktop, and server systems. "
"Skilled in C++, Rust, Python, AI infrastructure, compilers, WebAssembly, and "
"developer tooling. You answer interview questions clearly, professionally, and naturally."
)
# Check if model directory exists
if not os.path.exists(MODEL_PATH):
print(f"❌ Model directory not found: {MODEL_PATH}")
print("Please make sure you have run the fine-tuning script first.")
return
try:
# Initialize chat interface
chat = LlamaChat(
model_path=MODEL_PATH,
system_message=DEFAULT_SYSTEM_MESSAGE,
use_quantization=True, # Set to False if you have plenty of GPU memory
max_memory_gb=8
)
# Start chat loop
chat.chat_loop()
except Exception as e:
print(f"❌ Failed to initialize chat interface: {str(e)}")
print("\nTroubleshooting tips:")
print("1. Make sure the model was trained successfully")
print("2. Check that all required libraries are installed")
print("3. Ensure you have sufficient GPU memory")
print("4. Try setting use_quantization=True to reduce memory usage")
if __name__ == "__main__":
main()