|
|
from scripts.parsing_utils import load_yaml_file, get_roadmap_phases, get_project_rules |
|
|
import os |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
import yaml |
|
|
import logging |
|
|
import torch |
|
|
|
|
|
# Only ERROR-level (and above) messages are surfaced; informational chatter
# from the libraries below is suppressed.
logging.basicConfig(level=logging.ERROR,
                    format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
class ProjectGuidanceChatbot:
    """Project-guidance chatbot driven by YAML files (roadmap, rules, chatbot
    config) with a 4-bit quantized causal LLM as the free-form fallback."""

    def __init__(self, roadmap_file, rules_file, config_file, code_templates_dir):
        # Keep the source paths so the configuration can be re-read later
        # (reload_config is triggered when leaving update mode).
        self.roadmap_file = roadmap_file
        self.rules_file = rules_file
        self.config_file = config_file
        self.code_templates_dir = code_templates_dir

        # Raw parsed YAML contents (presumably None when loading fails —
        # the `if self.config_data else {}` guards below suggest so).
        self.roadmap_data = load_yaml_file(self.roadmap_file)
        self.rules_data = load_yaml_file(self.rules_file)
        self.config_data = load_yaml_file(self.config_file)

        # Derived views; every config section falls back to an empty dict so
        # later .get() calls are safe even with a missing config file.
        self.phases = get_roadmap_phases(self.roadmap_data)
        self.rules = get_project_rules(self.rules_data)
        self.chatbot_config = self.config_data.get('chatbot', {}) if self.config_data else {}
        self.model_config = self.config_data.get('model_selection', {}) if self.config_data else {}
        self.response_config = self.config_data.get('response_generation', {}) if self.config_data else {}
        self.available_models_config = self.config_data.get('available_models', {}) if self.config_data else {}
        self.max_response_tokens = self.chatbot_config.get('max_response_tokens', 200)

        # Conversation / model state.
        self.current_phase = None  # roadmap phase the conversation is focused on
        self.active_model_key = self.chatbot_config.get('default_llm_model_id')
        # May be None when the default key is absent from available_models;
        # load_llm_model handles that case by logging an error.
        self.active_model_info = self.available_models_config.get(self.active_model_key)

        # Load the default LLM eagerly; on failure both attributes stay None
        # and generate_llm_response returns an error string instead of crashing.
        self.llm_model = None
        self.llm_tokenizer = None
        self.load_llm_model(self.active_model_info)

        # True while the chatbot accepts 'sagor is python/...' config commands.
        self.update_mode_active = False
|
|
|
|
|
    def load_llm_model(self, model_info):
        """Loads the LLM model and tokenizer based on model_info with 4-bit quantization.

        On any failure (missing info, missing model_id, or a load error) both
        self.llm_model and self.llm_tokenizer are reset to None.
        """
        if not model_info:
            error_message = "Error: Model information not provided."
            logging.error(error_message)
            self.llm_model = None
            self.llm_tokenizer = None
            return

        model_id = model_info.get('model_id')
        model_name = model_info.get('name')
        if not model_id:
            error_message = f"Error: 'model_id' not found for model: {model_name}"
            logging.error(error_message)
            self.llm_model = None
            self.llm_tokenizer = None
            return

        print(f"Loading model: {model_name} ({model_id}) with 4-bit quantization...")
        try:
            # NF4 4-bit weights with bfloat16 compute to reduce memory usage.
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",  # let accelerate place layers across devices
                quantization_config=bnb_config
            )
            print(f"Model {model_name} loaded successfully with 4-bit quantization.")
        except Exception as e:
            error_message = f"Error loading model {model_name} ({model_id}) with 4-bit quantization: {e}"
            logging.exception(error_message)
            self.llm_model = None
            self.llm_tokenizer = None
        # NOTE(review): active_model_info is recorded even when loading failed,
        # so the greeting may name a model that is not actually loaded —
        # confirm this is intended.
        self.active_model_info = model_info
|
|
|
|
|
def switch_llm_model(self, model_key): |
|
|
"""Switches the active LLM model based on the provided model key.""" |
|
|
if model_key in self.available_models_config: |
|
|
model_info = self.available_models_config[model_key] |
|
|
print(f"Switching LLM model to: {model_info.get('name')}") |
|
|
self.load_llm_model(model_info) |
|
|
self.active_model_key = model_key |
|
|
return f"Switched to model: {model_info.get('name')}" |
|
|
else: |
|
|
error_message = f"Error: Model key '{model_key}' not found in available models." |
|
|
logging.error(error_message) |
|
|
return error_message |
|
|
|
|
|
def enter_update_mode(self): |
|
|
"""Enters the chatbot's update mode.""" |
|
|
self.update_mode_active = True |
|
|
return "Entering update mode. Please enter configuration commands (or 'sagor is python/help' for commands)." |
|
|
|
|
|
def exit_update_mode(self): |
|
|
"""Exits the chatbot's update mode and reloads configuration.""" |
|
|
self.update_mode_active = False |
|
|
self.reload_config() |
|
|
return "Exiting update mode. Configuration reloaded." |
|
|
|
|
|
def reload_config(self): |
|
|
"""Reloads configuration files.""" |
|
|
print("Reloading configuration...") |
|
|
try: |
|
|
self.config_data = load_yaml_file(self.config_file) |
|
|
self.roadmap_data = load_yaml_file(self.roadmap_file) |
|
|
self.rules_data = load_yaml_file(self.rules_file) |
|
|
self.chatbot_config = self.config_data.get('chatbot', {}) if self.config_data else {} |
|
|
self.model_config = self.config_data.get('model_selection', {}) if self.config_data else {} |
|
|
self.response_config = self.config_data.get('response_generation', {}) if self.config_data else {} |
|
|
self.available_models_config = self.config_data.get('available_models', {}) if self.config_data else {} |
|
|
self.max_response_tokens = self.chatbot_config.get('max_response_tokens', 200) |
|
|
self.phases = get_roadmap_phases(self.roadmap_data) |
|
|
self.rules = get_project_rules(self.rules_data) |
|
|
print("Configuration reloaded.") |
|
|
except Exception as e: |
|
|
error_message = f"Error reloading configuration files: {e}" |
|
|
logging.exception(error_message) |
|
|
print(error_message) |
|
|
|
|
|
def get_chatbot_greeting(self): |
|
|
current_model_name = self.active_model_info.get('name', 'Unknown Model') if self.active_model_info else 'Unknown Model' |
|
|
return f"Hello! I am the {self.chatbot_config.get('name', 'Project Guidance Chatbot')}. Currently using **{current_model_name}** (4-bit quantized). Max response tokens: {self.max_response_tokens}. {self.chatbot_config.get('description', 'How can I help you with your project?')}" |
|
|
|
|
|
def generate_llm_response(self, user_query): |
|
|
"""Generates a response using the currently active LLM.""" |
|
|
if not self.llm_model or not self.llm_tokenizer: |
|
|
error_message = "LLM model not loaded. Please select a model." |
|
|
logging.error(error_message) |
|
|
return error_message |
|
|
try: |
|
|
inputs = self.llm_tokenizer(user_query, return_tensors="pt").to(self.llm_model.device) |
|
|
outputs = self.llm_model.generate(**inputs, max_length=self.max_response_tokens, num_beams=5, no_repeat_ngram_size=2, early_stopping=True) |
|
|
response = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
return response |
|
|
except Exception as e: |
|
|
error_message = f"Error generating LLM response: {e}" |
|
|
logging.exception(error_message) |
|
|
return self.response_config.get('error_message', "Sorry, something went wrong.") |
|
|
|
|
|
    def process_query(self, user_query):
        """Top-level dispatcher for a user message.

        Precedence (order matters): secret admin commands, update-mode gate,
        loaded-data sanity checks, keyword shortcuts (roadmap/rules/model
        switch), current-phase guidance, phase selection by name, and finally
        a free-form LLM answer.
        """
        secret_command_prefix = "sagor is python/"

        # --- Secret admin commands -------------------------------------
        if user_query.startswith(secret_command_prefix):
            command = user_query[len(secret_command_prefix):].strip()
            if command == "start":
                return self.enter_update_mode()
            elif command == "exit":
                return self.exit_update_mode()
            elif command == "help":
                return self.get_update_mode_help()
            elif self.update_mode_active:
                # Remaining commands are only honored inside update mode.
                parts = command.split()
                if parts and parts[0] == "modify_rule":
                    if len(parts) >= 4:
                        category = parts[1]
                        rule_key = parts[2]
                        # Everything after the rule key is the new rule text.
                        new_rule_text_parts = parts[3:]
                        new_rule_text = " ".join(new_rule_text_parts)
                        return self.modify_rule_in_config(category, rule_key, new_rule_text)
                    else:
                        return "Invalid 'modify_rule' command format. Usage: modify_rule <category> <rule_key> <new rule text>"
                elif parts and parts[0] == "add_phase":
                    return "Command 'add_phase' is not yet implemented in this example."
                else:
                    return f"Unknown update command: {command}. Type 'sagor is python/help' for commands."
            else:
                return f"Unknown command (not in update mode): {command}. To enter update mode, type 'sagor is python/start'."

        # While in update mode, ordinary queries are rejected.
        if self.update_mode_active:
            return "In update mode. Please enter a configuration command (or 'sagor is python/help' for commands)."

        # Bail out early if the YAML data never loaded.
        if not self.phases:
            return "Error: Roadmap data not loaded correctly."
        if not self.rules:
            return "Error: Rules data not loaded correctly."

        query_lower = user_query.lower()

        # --- Keyword shortcuts -----------------------------------------
        if "roadmap" in query_lower:
            return self.get_roadmap_summary()
        if "rules" in query_lower:
            return self.get_rules_summary()
        if "switch model to" in query_lower:
            parts = user_query.lower().split("switch model to")
            if len(parts) > 1:
                model_name_or_key = parts[1].strip()
                # Match either the configured key or the display name.
                for key, model_data in self.available_models_config.items():
                    if model_data['name'].lower() == model_name_or_key.lower() or key.lower() == model_name_or_key.lower():
                        switch_result = self.switch_llm_model(key)
                        return switch_result + "\n" + self.get_chatbot_greeting()
                return f"Model '{model_name_or_key}' not found in available models."

        # --- Guidance within the currently selected phase --------------
        if self.current_phase:
            current_phase_data = self.phases.get(self.current_phase)
            if current_phase_data:
                phase_response = self.get_phase_guidance(current_phase_data, user_query)
                if phase_response:
                    return phase_response

        # Mentioning a phase name selects that phase for future queries.
        for phase_key, phase_data in self.phases.items():
            if phase_data['name'].lower() in query_lower:
                self.current_phase = phase_key
                return self.get_phase_introduction(phase_data)

        # --- Fallback: free-form LLM answer ----------------------------
        llm_response = self.generate_llm_response(user_query)
        if llm_response:
            return llm_response

        return self.response_config.get('default_instruction', "How can I help you with your project phases or general questions?")
|
|
|
|
|
def get_update_mode_help(self): |
|
|
"""Provides help message for update mode commands.""" |
|
|
help_message = "Update Mode Commands:\n" |
|
|
help_message += "- `sagor is python/exit`: Exit update mode and reload configuration.\n" |
|
|
help_message += "- `sagor is python/modify_rule <category> <rule_key> <new rule text>`: Modify a rule in rules.yaml.\n" |
|
|
help_message += " Example: `sagor is python/modify_rule general rule_1 Prioritize open and responsible AI.`\n" |
|
|
help_message += "- `sagor is python/add_phase ...`: (Not yet implemented) Add a new phase to roadmap.yaml.\n" |
|
|
help_message += "- `sagor is python/help`: Show this help message.\n" |
|
|
help_message += "\nMake sure to use the correct syntax for commands. After exiting update mode, the chatbot will reload the configuration." |
|
|
return help_message |
|
|
|
|
|
def modify_rule_in_config(self, category, rule_key, new_rule_text): |
|
|
"""Modifies a rule in the rules.yaml configuration.""" |
|
|
if not self.rules_data or 'project_rules' not in self.rules_data: |
|
|
error_message = "Error: Rules data not loaded or invalid format." |
|
|
logging.error(error_message) |
|
|
return error_message |
|
|
if category not in self.rules_data['project_rules']: |
|
|
error_message = f"Error: Rule category '{category}' not found." |
|
|
logging.error(error_message) |
|
|
return error_message |
|
|
if rule_key not in self.rules_data['project_rules'][category]: |
|
|
error_message = f"Error: Rule key '{rule_key}' not found in category '{category}'." |
|
|
logging.error(error_message) |
|
|
return error_message |
|
|
|
|
|
self.rules_data['project_rules'][category][rule_key] = new_rule_text |
|
|
|
|
|
try: |
|
|
with open(self.rules_file, 'w') as f: |
|
|
yaml.dump(self.rules_data, f, indent=2) |
|
|
self.reload_config() |
|
|
return f"Rule '{rule_key}' in category '{category}' updated to: '{new_rule_text}'. Configuration reloaded." |
|
|
except Exception as e: |
|
|
error_message = f"Error saving changes to {self.rules_file}: {e}" |
|
|
logging.exception(error_message) |
|
|
return error_message |
|
|
|
|
|
def get_roadmap_summary(self): |
|
|
summary = "Project Roadmap:\n" |
|
|
for phase_key, phase_data in self.phases.items(): |
|
|
summary += f"- **Phase: {phase_data['name']}**\n" |
|
|
summary += f" Description: {phase_data['description']}\n" |
|
|
summary += f" Milestones: {', '.join(phase_data['milestones'])}\n" |
|
|
return summary |
|
|
|
|
|
def get_rules_summary(self): |
|
|
summary = "Project Rules:\n" |
|
|
for rule_category, rules_list in self.rules.items(): |
|
|
summary += f"**{rule_category.capitalize()} Rules:**\n" |
|
|
for rule_key, rule_text in rules_list.items(): |
|
|
summary += f"- {rule_text}\n" |
|
|
return summary |
|
|
|
|
|
def get_phase_introduction(self, phase_data): |
|
|
return f"Okay, let's focus on **Phase: {phase_data['name']}**. \nDescription: {phase_data['description']}. \nKey milestones are: {', '.join(phase_data['milestones'])}. \nWhat would you like to know or do in this phase?" |
|
|
|
|
|
def get_phase_guidance(self, phase_data, user_query): |
|
|
query_lower = user_query.lower() |
|
|
|
|
|
if "milestones" in query_lower: |
|
|
return "The milestones for this phase are: " + ", ".join(phase_data['milestones']) |
|
|
if "actions" in query_lower or "how to" in query_lower: |
|
|
if 'actions' in phase_data: |
|
|
return "Recommended actions for this phase: " + ", ".join(phase_data['actions']) |
|
|
else: |
|
|
return "No specific actions are listed for this phase in the roadmap." |
|
|
if "code" in query_lower or "script" in query_lower: |
|
|
if 'code_generation_hint' in phase_data: |
|
|
template_filename_prefix = phase_data['name'].lower().replace(" ", "_") |
|
|
template_filepath = os.path.join(self.code_templates_dir, f"{template_filename_prefix}_template.py.txt") |
|
|
if os.path.exists(template_filepath): |
|
|
code_snippet = self.generate_code_snippet(template_filepath, phase_data) |
|
|
return "Here's a starting code snippet for this phase:\n\n```python\n" + code_snippet + "\n```\n\nRemember to adapt it to your specific needs." |
|
|
else: |
|
|
return f"A code template for this phase ({phase_data['name']}) is not yet available. However, the hint is: {phase_data['code_generation_hint']}" |
|
|
else: |
|
|
return "No code generation hint is available for this phase." |
|
|
|
|
|
return f"For phase '{phase_data['name']}', remember the description: {phase_data['description']}. Consider the milestones and actions. What specific aspect are you interested in?" |
|
|
|
|
|
def generate_code_snippet(self, template_filepath, phase_data): |
|
|
"""Generates code snippet from a template file. (Simple template filling example)""" |
|
|
try: |
|
|
with open(template_filepath, 'r') as f: |
|
|
template_content = f.read() |
|
|
|
|
|
code_snippet = template_content.replace("{{phase_name}}", phase_data['name']) |
|
|
return code_snippet |
|
|
except FileNotFoundError: |
|
|
return f"Error: Code template file not found at {template_filepath}" |
|
|
except Exception as e: |
|
|
return f"Error generating code snippet: {e}" |
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Interactive REPL for manual testing. Paths are relative, so run this
    # from the repository root where roadmap.yaml / rules.yaml / configs live.
    chatbot = ProjectGuidanceChatbot(
        roadmap_file="roadmap.yaml",
        rules_file="rules.yaml",
        config_file="configs/chatbot_config.yaml",
        code_templates_dir="scripts/code_templates"
    )
    print(chatbot.get_chatbot_greeting())

    # Loop until the user types 'exit' (case-insensitive).
    while True:
        user_input = input("You: ")
        if user_input.lower() == "exit":
            break
        response = chatbot.process_query(user_input)
        print("Chatbot:", response)