| | import os |
| | import torch |
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModelForCausalLM, |
| | BitsAndBytesConfig |
| | ) |
| | from peft import PeftModel |
| | import warnings |
| | from datetime import datetime |
| | import json |
| |
|
| | |
# Silence the deprecation/user warnings the ML libraries emit on import and load,
# so the interactive console output stays readable.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Disable tokenizer-level parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
| |
|
class LlamaChat:
    """Interactive console chat around a fine-tuned Llama causal language model.

    The model directory may hold either a full fine-tuned checkpoint or a PEFT
    (LoRA) adapter; in the adapter case the base model named in
    ``adapter_config.json`` is loaded first and the adapter is applied on top.

    Every user message is answered independently: the prompt contains only the
    system message and the current user message. ``conversation_history`` is
    kept for display and saving, it is not fed back into the model.
    """

    # End-of-turn marker used by the prompt built in ``_format_message``.
    _END_OF_TURN_TOKEN = "<|eot_id|>"

    def __init__(self, model_path, system_message=None, use_quantization=True, max_memory_gb=8):
        """
        Initialize the chat interface with the fine-tuned Llama model

        Args:
            model_path: Path to the fine-tuned model directory
            system_message: System message to use for conversations (persona/context).
                A built-in persona is used when this is None or empty.
            use_quantization: Whether to use 4-bit quantization (recommended for 8GB GPU)
            max_memory_gb: Maximum GPU memory to use. NOTE: the value is stored on the
                instance but is not currently passed to the model loader.
        """
        self.model_path = model_path
        self.use_quantization = use_quantization
        self.max_memory_gb = max_memory_gb

        self.system_message = system_message or (
            "You are Alexander Molchevskyi β a senior software engineer with over 20 years "
            "of professional experience across embedded, desktop, and server systems. "
            "Skilled in C++, Rust, Python, AI infrastructure, compilers, WebAssembly, and "
            "developer tooling. You answer interview questions clearly, professionally, and naturally."
        )

        print("π Loading Llama Chat Interface...")
        print(f"Model path: {model_path}")
        print(f"System message: {self.system_message[:100]}{'...' if len(self.system_message) > 100 else ''}")

        if torch.cuda.is_available():
            print(f"β CUDA available: {torch.cuda.get_device_name()}")
            print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
        else:
            print("β οΈ CUDA not available, using CPU (will be slow)")

        self.tokenizer = None
        self.model = None
        self.conversation_history = []

        # Loading happens eagerly; any failure is printed and re-raised.
        self._load_model()

    def _setup_quantization_config(self):
        """Setup 4-bit quantization config for memory efficiency

        Returns:
            A ``BitsAndBytesConfig`` (4-bit NF4, double quantization, bfloat16
            compute) or None when quantization is disabled.
        """
        if not self.use_quantization:
            return None

        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    def _load_model(self):
        """Load the tokenizer and model

        Populates ``self.tokenizer`` and ``self.model`` and puts the model in
        eval mode.

        Raises:
            Exception: whatever the underlying loaders raise; the error is
                printed first and then re-raised unchanged.
        """
        try:
            print("π Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                padding_side="left",
            )

            # Fall back to EOS as padding when the tokenizer defines no pad token.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

            print("π§ Loading base model...")

            # Identical loader options for the base model and the full checkpoint.
            load_kwargs = {
                "quantization_config": self._setup_quantization_config(),
                "device_map": "auto",
                "torch_dtype": torch.bfloat16,
                "trust_remote_code": True,
                "use_cache": True,
            }

            # An adapter_config.json in the directory marks a PEFT (LoRA) adapter.
            adapter_config_path = os.path.join(self.model_path, "adapter_config.json")

            if os.path.exists(adapter_config_path):
                print("π§ Detected PEFT (LoRA) model, loading base model first...")

                with open(adapter_config_path, "r", encoding="utf-8") as f:
                    adapter_config = json.load(f)

                base_model_name = adapter_config.get('base_model_name_or_path', 'llama-3.2-3b')
                print(f"Base model: {base_model_name}")

                base_model = AutoModelForCausalLM.from_pretrained(base_model_name, **load_kwargs)

                print("π― Loading LoRA adapter...")
                self.model = PeftModel.from_pretrained(base_model, self.model_path)
            else:
                print("π¦ Loading fine-tuned model...")
                self.model = AutoModelForCausalLM.from_pretrained(self.model_path, **load_kwargs)

            self.model.eval()
            print("β Model loaded successfully!")

            # Only PEFT-wrapped models expose this summary helper.
            if hasattr(self.model, 'print_trainable_parameters'):
                self.model.print_trainable_parameters()

        except Exception as e:
            print(f"β Error loading model: {str(e)}")
            raise

    def _format_message(self, user_message):
        """Format user message with system context using Llama's chat template

        The returned prompt already starts with ``<|begin_of_text|>`` and ends
        with an open assistant header, so it must be tokenized without adding
        special tokens again.
        """
        return f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{self.system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    def _stop_token_ids(self):
        """Return the token ids that terminate generation.

        Always contains the tokenizer's EOS id (when defined) and additionally
        the end-of-turn token id when the tokenizer's vocabulary knows it, so
        generation stops at the end of the assistant turn.
        """
        stop_ids = []
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            stop_ids.append(eos_id)

        eot_id = self.tokenizer.convert_tokens_to_ids(self._END_OF_TURN_TOKEN)
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        # Unknown tokens map to the unk id (possibly None); ignore those.
        if isinstance(eot_id, int) and eot_id != unk_id and eot_id not in stop_ids:
            stop_ids.append(eot_id)

        return stop_ids

    def generate_response(self, user_message, max_new_tokens=200, temperature=0.7,
                          top_p=0.9, repetition_penalty=1.1, do_sample=True):
        """
        Generate a response to the user message

        Args:
            user_message: The user's input message
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling parameter
            repetition_penalty: Penalty for repeating tokens
            do_sample: Whether to use sampling or greedy decoding

        Returns:
            The assistant's reply text only (the prompt is not echoed). On any
            failure an error message string is returned instead of raising.
        """
        try:
            formatted_input = self._format_message(user_message)

            # The prompt already carries <|begin_of_text|>; letting the tokenizer
            # add special tokens would prepend a second BOS.
            inputs = self.tokenizer(
                formatted_input,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
                add_special_tokens=False,
            ).to(self.model.device)

            generation_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": do_sample,
                "repetition_penalty": repetition_penalty,
                "pad_token_id": self.tokenizer.eos_token_id,
                "eos_token_id": self._stop_token_ids() or None,
                "num_return_sequences": 1,
            }
            # Sampling knobs are meaningless (and warned about) for greedy decoding.
            if do_sample:
                generation_kwargs["temperature"] = temperature
                generation_kwargs["top_p"] = top_p

            print("π€ Thinking...")

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **generation_kwargs)

            # generate() returns prompt + continuation. Decoding with
            # skip_special_tokens=True strips the header markers, so the reply
            # cannot be found by splitting the text; slice by prompt length.
            prompt_length = inputs["input_ids"].shape[-1]
            reply_ids = outputs[0][prompt_length:]
            assistant_response = self.tokenizer.decode(reply_ids, skip_special_tokens=True)

            return assistant_response.replace(self._END_OF_TURN_TOKEN, "").strip()

        except Exception as e:
            return f"β Error generating response: {str(e)}"

    def chat_loop(self):
        """Main chat loop

        Reads lines from stdin until '/quit', '/exit', Ctrl-C or end of input.
        Slash commands are matched case-insensitively; anything else is sent
        to the model and the exchange is appended to ``conversation_history``.
        """
        print("\n" + "="*60)
        print("π¦ LLAMA FINE-TUNED CHAT INTERFACE")
        print("="*60)
        print("Commands:")
        print(" β’ Type your message and press Enter")
        print(" β’ '/help' - Show this help")
        print(" β’ '/system' - View or change system message")
        print(" β’ '/settings' - Adjust generation settings")
        print(" β’ '/history' - Show conversation history")
        print(" β’ '/clear' - Clear conversation history")
        print(" β’ '/save' - Save conversation to file")
        print(" β’ '/quit' or '/exit' - Exit the chat")
        print("="*60)

        # Generation settings for this session; editable through /settings.
        settings = {
            'max_new_tokens': 200,
            'temperature': 0.7,
            'top_p': 0.9,
            'repetition_penalty': 1.1,
            'do_sample': True
        }

        while True:
            try:
                user_input = input("\nπ€ You: ").strip()

                if not user_input:
                    continue

                command = user_input.lower()

                if command in ('/quit', '/exit'):
                    print("π Goodbye!")
                    break

                if command == '/help':
                    self._show_help()
                    continue

                if command == '/system':
                    self._manage_system_message()
                    continue

                if command == '/settings':
                    settings = self._adjust_settings(settings)
                    continue

                if command == '/history':
                    self._show_history()
                    continue

                if command == '/clear':
                    self.conversation_history.clear()
                    print("π§Ή Conversation history cleared!")
                    continue

                if command == '/save':
                    self._save_conversation()
                    continue

                response = self.generate_response(user_input, **settings)

                print(f"\nπ¦ Alexander: {response}")

                self.conversation_history.append({
                    'timestamp': datetime.now().isoformat(),
                    'system': self.system_message,
                    'user': user_input,
                    'assistant': response
                })

            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin was closed (Ctrl-D / piped input). Without this
                # the generic handler below would report it forever in a loop.
                print("\n\nπ Chat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\nβ Error: {str(e)}")

    def _manage_system_message(self):
        """Allow user to view or change the system message

        The current message is always printed first; choosing 'c'/'change'
        prompts for a replacement, any other answer leaves it untouched.
        """
        print("\nπ€ SYSTEM MESSAGE MANAGEMENT:")
        print("Current system message:")
        print("-" * 60)
        print(self.system_message)
        print("-" * 60)

        choice = input("\nOptions: [v]iew, [c]hange, or [Enter] to go back: ").strip().lower()

        # 'v'/'view' needs no action: the message was already shown above.
        if choice not in ('c', 'change'):
            return

        print("\nEnter new system message (or press Enter to keep current):")
        new_system = input("> ").strip()

        if new_system:
            self.system_message = new_system
            print("β System message updated!")
            print("Note: This will affect all future conversations.")
        else:
            print("System message unchanged.")

    def _show_help(self):
        """Show help information"""
        print("\nπ HELP:")
        print("This is a chat interface for your fine-tuned Llama model.")
        print("The model has been trained with system messages to embody Alexander Molchevskyi's")
        print("professional persona and expertise in software engineering.")
        print("\nTips:")
        print("β’ Ask technical questions about software engineering, AI, or development")
        print("β’ The model maintains context of being Alexander throughout conversations")
        print("β’ Use /system to view or modify the professional persona")
        print("β’ Use /settings to adjust creativity (temperature) and response length")
        print("β’ Higher temperature = more creative but less consistent")
        print("β’ Lower temperature = more focused and consistent")

    def _adjust_settings(self, current_settings):
        """Allow user to adjust generation settings

        Args:
            current_settings: Dict of generation settings currently in effect.

        Returns:
            A new dict with the entered values clamped to their allowed ranges.
            An empty answer keeps the current value. If any answer is not a
            valid number, ``current_settings`` is returned unchanged.
        """
        print("\nβοΈ GENERATION SETTINGS:")
        print("Current settings:")
        for key, value in current_settings.items():
            print(f" {key}: {value}")

        new_settings = current_settings.copy()

        # (settings key, prompt label, parser, lower bound, upper bound)
        prompts = (
            ('max_new_tokens', "\nMax response length", int, 1, 500),
            ('temperature', "Temperature 0.1-2.0", float, 0.1, 2.0),
            ('top_p', "Top-p 0.1-1.0", float, 0.1, 1.0),
            ('repetition_penalty', "Repetition penalty 1.0-2.0", float, 1.0, 2.0),
        )

        try:
            for key, label, parse, low, high in prompts:
                raw = input(f"{label} ({current_settings[key]}): ").strip()
                if raw:
                    new_settings[key] = max(low, min(high, parse(raw)))
        except ValueError:
            print("β Invalid input. Settings unchanged.")
            return current_settings

        print("β Settings updated!")
        return new_settings

    def _show_history(self):
        """Show conversation history

        Prints the five most recent exchanges, with long assistant replies
        truncated to 100 characters, and a count of any older exchanges.
        """
        if not self.conversation_history:
            print("π No conversation history yet.")
            return

        print(f"\nπ CONVERSATION HISTORY ({len(self.conversation_history)} exchanges):")
        print("-" * 50)

        for exchange in self.conversation_history[-5:]:
            # ISO timestamp -> time of day without fractional seconds.
            timestamp = exchange['timestamp'].split('T')[1].split('.')[0]
            print(f"\n[{timestamp}]")
            print(f"π€ You: {exchange['user']}")
            print(f"π¦ Alexander: {exchange['assistant'][:100]}{'...' if len(exchange['assistant']) > 100 else ''}")

        if len(self.conversation_history) > 5:
            print(f"\n... and {len(self.conversation_history) - 5} more exchanges")

    def _save_conversation(self):
        """Save conversation to a JSON file

        The file is written to the current working directory as
        ``llama_chat_<YYYYmmdd_HHMMSS>.json``. Errors are printed, not raised.
        """
        if not self.conversation_history:
            print("π No conversation to save.")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"llama_chat_{timestamp}.json"

        try:
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(self.conversation_history, f, indent=2, ensure_ascii=False)
            print(f"πΎ Conversation saved to: {filename}")
        except Exception as e:
            print(f"β Error saving conversation: {str(e)}")
| |
|
def main(model_path="llama-3.2-3b-finetuned"):
    """Main function to start the chat interface

    Args:
        model_path: Directory containing the fine-tuned model or LoRA adapter.
            Defaults to the output directory name used previously.

    Returns:
        None. A missing model directory or a start-up failure is reported on
        stdout instead of being raised.
    """
    DEFAULT_SYSTEM_MESSAGE = (
        "You are Alexander Molchevskyi β a senior software engineer with over 20 years "
        "of professional experience across embedded, desktop, and server systems. "
        "Skilled in C++, Rust, Python, AI infrastructure, compilers, WebAssembly, and "
        "developer tooling. You answer interview questions clearly, professionally, and naturally."
    )

    # The loader needs a directory; a stray file with the same name is not enough.
    if not os.path.isdir(model_path):
        print(f"β Model directory not found: {model_path}")
        print("Please make sure you have run the fine-tuning script first.")
        return

    try:
        chat = LlamaChat(
            model_path=model_path,
            system_message=DEFAULT_SYSTEM_MESSAGE,
            use_quantization=True,
            max_memory_gb=8
        )

        chat.chat_loop()

    except Exception as e:
        print(f"β Failed to initialize chat interface: {str(e)}")
        print("\nTroubleshooting tips:")
        print("1. Make sure the model was trained successfully")
        print("2. Check that all required libraries are installed")
        print("3. Ensure you have sufficient GPU memory")
        # Quantization is already enabled above, so suggest the remaining option.
        print("4. 4-bit quantization is already enabled; free GPU memory used by other processes")


if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default model directory.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()
| |
|