from scripts.parsing_utils import load_yaml_file, get_roadmap_phases, get_project_rules
import os
import logging

import torch
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')


class ProjectGuidanceChatbot:
    """Project-guidance chatbot driven by YAML roadmap/rules/config files and a
    4-bit quantized Hugging Face causal LLM.
    """

    def __init__(self, roadmap_file, rules_file, config_file, code_templates_dir):
        """Load all YAML configuration and the default LLM.

        Args:
            roadmap_file: Path to the roadmap YAML file.
            rules_file: Path to the project rules YAML file.
            config_file: Path to the chatbot configuration YAML file.
            code_templates_dir: Directory holding code template files.
        """
        self.roadmap_file = roadmap_file
        self.rules_file = rules_file
        self.config_file = config_file
        self.code_templates_dir = code_templates_dir

        self.roadmap_data = load_yaml_file(self.roadmap_file)
        self.rules_data = load_yaml_file(self.rules_file)
        self.config_data = load_yaml_file(self.config_file)
        self.phases = get_roadmap_phases(self.roadmap_data)
        self.rules = get_project_rules(self.rules_data)

        # Each config section defaults to {} when the config file failed to load.
        self.chatbot_config = self.config_data.get('chatbot', {}) if self.config_data else {}
        self.model_config = self.config_data.get('model_selection', {}) if self.config_data else {}
        self.response_config = self.config_data.get('response_generation', {}) if self.config_data else {}
        self.available_models_config = self.config_data.get('available_models', {}) if self.config_data else {}
        self.max_response_tokens = self.chatbot_config.get('max_response_tokens', 200)

        self.current_phase = None  # key of the phase the user is currently focused on
        self.active_model_key = self.chatbot_config.get('default_llm_model_id')
        self.active_model_info = self.available_models_config.get(self.active_model_key)
        self.llm_model = None
        self.llm_tokenizer = None
        self.load_llm_model(self.active_model_info)
        self.update_mode_active = False

    def load_llm_model(self, model_info):
        """Load the LLM model and tokenizer described by *model_info* with
        4-bit (nf4) quantization.

        On any failure ``llm_model``/``llm_tokenizer`` are reset to ``None``
        and the previously recorded ``active_model_info`` is left untouched.
        """
        if not model_info:
            logging.error("Error: Model information not provided.")
            self.llm_model = None
            self.llm_tokenizer = None
            return
        model_id = model_info.get('model_id')
        model_name = model_info.get('name')
        if not model_id:
            logging.error(f"Error: 'model_id' not found for model: {model_name}")
            self.llm_model = None
            self.llm_tokenizer = None
            return
        print(f"Loading model: {model_name} ({model_id}) with 4-bit quantization...")
        try:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",  # "nf4" is recommended for Llama models
                bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if bfloat16 not supported
            )
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_config,  # apply 4-bit quantization
            )
            print(f"Model {model_name} loaded successfully with 4-bit quantization.")
            # BUGFIX: record the active model info only after a SUCCESSFUL load.
            # Previously this assignment ran unconditionally after the
            # try/except, so a failed load clobbered the previous model's info.
            self.active_model_info = model_info
        except Exception as e:
            logging.exception(f"Error loading model {model_name} ({model_id}) with 4-bit quantization: {e}")
            self.llm_model = None
            self.llm_tokenizer = None

    def switch_llm_model(self, model_key):
        """Switch the active LLM to *model_key*.

        Returns a human-readable status message; on failure the previous
        ``active_model_key`` is preserved.
        """
        if model_key not in self.available_models_config:
            error_message = f"Error: Model key '{model_key}' not found in available models."
            logging.error(error_message)
            return error_message
        model_info = self.available_models_config[model_key]
        print(f"Switching LLM model to: {model_info.get('name')}")
        self.load_llm_model(model_info)
        # BUGFIX: previously this method reported success and updated
        # active_model_key even when load_llm_model had failed (llm_model None).
        if self.llm_model is None:
            return f"Error: Failed to load model: {model_info.get('name')}"
        self.active_model_key = model_key
        return f"Switched to model: {model_info.get('name')}"

    def enter_update_mode(self):
        """Enter the chatbot's update mode."""
        self.update_mode_active = True
        return "Entering update mode. Please enter configuration commands (or 'sagor is python/help' for commands)."
def exit_update_mode(self): """Exits the chatbot's update mode and reloads configuration.""" self.update_mode_active = False self.reload_config() return "Exiting update mode. Configuration reloaded." def reload_config(self): """Reloads configuration files.""" print("Reloading configuration...") try: self.config_data = load_yaml_file(self.config_file) self.roadmap_data = load_yaml_file(self.roadmap_file) self.rules_data = load_yaml_file(self.rules_file) self.chatbot_config = self.config_data.get('chatbot', {}) if self.config_data else {} self.model_config = self.config_data.get('model_selection', {}) if self.config_data else {} self.response_config = self.config_data.get('response_generation', {}) if self.config_data else {} self.available_models_config = self.config_data.get('available_models', {}) if self.config_data else {} self.max_response_tokens = self.chatbot_config.get('max_response_tokens', 200) self.phases = get_roadmap_phases(self.roadmap_data) self.rules = get_project_rules(self.rules_data) print("Configuration reloaded.") except Exception as e: error_message = f"Error reloading configuration files: {e}" logging.exception(error_message) print(error_message) def get_chatbot_greeting(self): current_model_name = self.active_model_info.get('name', 'Unknown Model') if self.active_model_info else 'Unknown Model' return f"Hello! I am the {self.chatbot_config.get('name', 'Project Guidance Chatbot')}. Currently using **{current_model_name}** (4-bit quantized). Max response tokens: {self.max_response_tokens}. {self.chatbot_config.get('description', 'How can I help you with your project?')}" # Indicate quantization in greeting def generate_llm_response(self, user_query): """Generates a response using the currently active LLM.""" if not self.llm_model or not self.llm_tokenizer: error_message = "LLM model not loaded. Please select a model." 
logging.error(error_message) return error_message try: inputs = self.llm_tokenizer(user_query, return_tensors="pt").to(self.llm_model.device) outputs = self.llm_model.generate(**inputs, max_length=self.max_response_tokens, num_beams=5, no_repeat_ngram_size=2, early_stopping=True) response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True) return response except Exception as e: error_message = f"Error generating LLM response: {e}" logging.exception(error_message) return self.response_config.get('error_message', "Sorry, something went wrong.") def process_query(self, user_query): secret_command_prefix = "sagor is python/" if user_query.startswith(secret_command_prefix): command = user_query[len(secret_command_prefix):].strip() if command == "start": return self.enter_update_mode() elif command == "exit": return self.exit_update_mode() elif command == "help": return self.get_update_mode_help() elif self.update_mode_active: parts = command.split() if parts and parts[0] == "modify_rule": if len(parts) >= 4: category = parts[1] rule_key = parts[2] new_rule_text_parts = parts[3:] new_rule_text = " ".join(new_rule_text_parts) return self.modify_rule_in_config(category, rule_key, new_rule_text) else: return "Invalid 'modify_rule' command format. Usage: modify_rule " elif parts and parts[0] == "add_phase": return "Command 'add_phase' is not yet implemented in this example." else: return f"Unknown update command: {command}. Type 'sagor is python/help' for commands." else: return f"Unknown command (not in update mode): {command}. To enter update mode, type 'sagor is python/start'." if self.update_mode_active: return "In update mode. Please enter a configuration command (or 'sagor is python/help' for commands)." if not self.phases: return "Error: Roadmap data not loaded correctly." if not self.rules: return "Error: Rules data not loaded correctly." 
query_lower = user_query.lower() if "roadmap" in query_lower: return self.get_roadmap_summary() if "rules" in query_lower: return self.get_rules_summary() if "switch model to" in query_lower: parts = user_query.lower().split("switch model to") if len(parts) > 1: model_name_or_key = parts[1].strip() for key, model_data in self.available_models_config.items(): if model_data['name'].lower() == model_name_or_key.lower() or key.lower() == model_name_or_key.lower(): switch_result = self.switch_llm_model(key) return switch_result + "\n" + self.get_chatbot_greeting() return f"Model '{model_name_or_key}' not found in available models." if self.current_phase: current_phase_data = self.phases.get(self.current_phase) if current_phase_data: phase_response = self.get_phase_guidance(current_phase_data, user_query) if phase_response: return phase_response for phase_key, phase_data in self.phases.items(): if phase_data['name'].lower() in query_lower: self.current_phase = phase_key return self.get_phase_introduction(phase_data) llm_response = self.generate_llm_response(user_query) if llm_response: return llm_response return self.response_config.get('default_instruction', "How can I help you with your project phases or general questions?") def get_update_mode_help(self): """Provides help message for update mode commands.""" help_message = "Update Mode Commands:\n" help_message += "- `sagor is python/exit`: Exit update mode and reload configuration.\n" help_message += "- `sagor is python/modify_rule `: Modify a rule in rules.yaml.\n" help_message += " Example: `sagor is python/modify_rule general rule_1 Prioritize open and responsible AI.`\n" help_message += "- `sagor is python/add_phase ...`: (Not yet implemented) Add a new phase to roadmap.yaml.\n" help_message += "- `sagor is python/help`: Show this help message.\n" help_message += "\nMake sure to use the correct syntax for commands. After exiting update mode, the chatbot will reload the configuration." 
return help_message def modify_rule_in_config(self, category, rule_key, new_rule_text): """Modifies a rule in the rules.yaml configuration.""" if not self.rules_data or 'project_rules' not in self.rules_data: error_message = "Error: Rules data not loaded or invalid format." logging.error(error_message) return error_message if category not in self.rules_data['project_rules']: error_message = f"Error: Rule category '{category}' not found." logging.error(error_message) return error_message if rule_key not in self.rules_data['project_rules'][category]: error_message = f"Error: Rule key '{rule_key}' not found in category '{category}'." logging.error(error_message) return error_message self.rules_data['project_rules'][category][rule_key] = new_rule_text try: with open(self.rules_file, 'w') as f: yaml.dump(self.rules_data, f, indent=2) self.reload_config() return f"Rule '{rule_key}' in category '{category}' updated to: '{new_rule_text}'. Configuration reloaded." except Exception as e: error_message = f"Error saving changes to {self.rules_file}: {e}" logging.exception(error_message) return error_message def get_roadmap_summary(self): summary = "Project Roadmap:\n" for phase_key, phase_data in self.phases.items(): summary += f"- **Phase: {phase_data['name']}**\n" summary += f" Description: {phase_data['description']}\n" summary += f" Milestones: {', '.join(phase_data['milestones'])}\n" return summary def get_rules_summary(self): summary = "Project Rules:\n" for rule_category, rules_list in self.rules.items(): summary += f"**{rule_category.capitalize()} Rules:**\n" for rule_key, rule_text in rules_list.items(): summary += f"- {rule_text}\n" return summary def get_phase_introduction(self, phase_data): return f"Okay, let's focus on **Phase: {phase_data['name']}**. \nDescription: {phase_data['description']}. \nKey milestones are: {', '.join(phase_data['milestones'])}. \nWhat would you like to know or do in this phase?" 
def get_phase_guidance(self, phase_data, user_query): query_lower = user_query.lower() if "milestones" in query_lower: return "The milestones for this phase are: " + ", ".join(phase_data['milestones']) if "actions" in query_lower or "how to" in query_lower: if 'actions' in phase_data: return "Recommended actions for this phase: " + ", ".join(phase_data['actions']) else: return "No specific actions are listed for this phase in the roadmap." if "code" in query_lower or "script" in query_lower: if 'code_generation_hint' in phase_data: template_filename_prefix = phase_data['name'].lower().replace(" ", "_") template_filepath = os.path.join(self.code_templates_dir, f"{template_filename_prefix}_template.py.txt") if os.path.exists(template_filepath): code_snippet = self.generate_code_snippet(template_filepath, phase_data) return "Here's a starting code snippet for this phase:\n\n```python\n" + code_snippet + "\n```\n\nRemember to adapt it to your specific needs." else: return f"A code template for this phase ({phase_data['name']}) is not yet available. However, the hint is: {phase_data['code_generation_hint']}" else: return "No code generation hint is available for this phase." return f"For phase '{phase_data['name']}', remember the description: {phase_data['description']}. Consider the milestones and actions. What specific aspect are you interested in?" def generate_code_snippet(self, template_filepath, phase_data): """Generates code snippet from a template file. 
(Simple template filling example)""" try: with open(template_filepath, 'r') as f: template_content = f.read() code_snippet = template_content.replace("{{phase_name}}", phase_data['name']) return code_snippet except FileNotFoundError: return f"Error: Code template file not found at {template_filepath}" except Exception as e: return f"Error generating code snippet: {e}" # Example usage (for testing - remove or adjust for app.py) if __name__ == '__main__': chatbot = ProjectGuidanceChatbot( roadmap_file="roadmap.yaml", rules_file="rules.yaml", config_file="configs/chatbot_config.yaml", code_templates_dir="scripts/code_templates" ) print(chatbot.get_chatbot_greeting()) while True: user_input = input("You: ") if user_input.lower() == "exit": break response = chatbot.process_query(user_input) print("Chatbot:", response)