"""
Input Analysis Agent

Analyzes user prompts to understand task type and intent. An LLM performs
the analysis; a keyword heuristic supplies any field the LLM call fails to
produce (errors, missing fields, or labels outside the allowed sets).
"""
import logging
import os
import re
from typing import Any, Dict, List

from langchain.prompts import PromptTemplate
from langchain_community.llms import OpenAI

logger = logging.getLogger(__name__)

# Prompt sent to the LLM. The "Format your response as" section is one
# "LABEL: value" pair per line because _parse_analysis_result splits the
# reply on newlines.
_ANALYSIS_TEMPLATE = """
Analyze the following user prompt and extract key information:

User Prompt: "{user_prompt}"

Please provide:
1. Task Type (choose **exactly one** from the list below):
   [text_generation, question_answering, code_generation, summarization, translation, reasoning, creative_writing, data_analysis]
2. Complexity Level (simple, moderate, complex)
3. Key Entities (extract important nouns, concepts, or variables)
4. Intent (what the user wants to achieve)
5. Domain (e.g., general, technical, academic, creative, business)

Format your response as:
TASK_TYPE: [task_type]
COMPLEXITY: [complexity_level]
ENTITIES: [comma-separated list]
INTENT: [brief description]
DOMAIN: [domain]
"""

# Keyword lists for the heuristic fallback. Order matters: the first task
# with a matching keyword wins, so earlier entries take precedence.
_TASK_KEYWORDS: Dict[str, List[str]] = {
    'question_answering': ['what', 'how', 'why', 'when', 'where', 'who', 'explain', 'answer'],
    'code_generation': ['code', 'function', 'program', 'script', 'algorithm', 'implement'],
    'summarization': ['summarize', 'summary', 'brief', 'overview', 'key points'],
    'translation': ['translate', 'translation', 'convert to', 'in spanish', 'in french'],
    'creative_writing': ['story', 'poem', 'creative', 'write', 'compose', 'imagine'],
    'reasoning': ['solve', 'calculate', 'reason', 'logic', 'think', 'analyze'],
    'text_generation': ['generate', 'create', 'write', 'produce'],
}


def _keyword_pattern(keyword: str) -> str:
    """Return a regex fragment matching *keyword* as a word, not a substring."""
    # Always anchor at the start of a word so "how" does not fire inside
    # "show" and "code" does not fire inside "decode".
    escaped = r"\b" + re.escape(keyword)
    if len(keyword) <= 4:
        # Short keywords (mostly question words) must also end the word,
        # otherwise "how" matches "however" and "who" matches "whole".
        return escaped + r"s?\b"
    # Longer keywords may carry any suffix ("implement" -> "implementation").
    return escaped + r"\w*"


# One compiled alternation per task, built once at import time.
_TASK_PATTERNS = {
    task: re.compile("|".join(_keyword_pattern(kw) for kw in keywords))
    for task, keywords in _TASK_KEYWORDS.items()
}

# Capitalized words, used as a crude entity detector by the fallback.
_ENTITY_RE = re.compile(r'\b[A-Z][a-z]+\b')

# Normalized response label -> result key. Accepts the labels requested in
# the template plus the long-form names an LLM may echo back instead.
_LABEL_ALIASES = {
    'task_type': 'task_type',
    'task': 'task_type',
    'complexity': 'complexity',
    'complexity_level': 'complexity',
    'entities': 'entities',
    'key_entities': 'entities',
    'intent': 'intent',
    'domain': 'domain',
}


def _normalize_label(text: str) -> str:
    """Lower-case *text* and collapse every non-letter run into one underscore.

    E.g. "**TASK_TYPE**" -> "task_type", "Code Generation" -> "code_generation",
    "[simple]" -> "simple".
    """
    return re.sub(r'[^a-z]+', '_', text.lower()).strip('_')


def _unwrap(value: str) -> str:
    """Strip whitespace, markdown bold markers and one enclosing [...] pair."""
    value = value.strip().strip('*').strip()
    if len(value) >= 2 and value[0] == '[' and value[-1] == ']':
        value = value[1:-1].strip()
    return value


class InputAnalysisAgent:
    """Classify a user prompt by task type, complexity, entities, intent and domain."""

    # Allowed labels; keep in sync with the lists in _ANALYSIS_TEMPLATE.
    TASK_TYPES = (
        'text_generation', 'question_answering', 'code_generation', 'summarization',
        'translation', 'reasoning', 'creative_writing', 'data_analysis',
    )
    COMPLEXITY_LEVELS = ('simple', 'moderate', 'complex')

    def __init__(self, *, model_name: str = "gpt-3.5-turbo-instruct", temperature: float = 0.1):
        """Create the LLM client and the analysis prompt.

        Args:
            model_name: Completion model passed to ``OpenAI``.
            temperature: Sampling temperature; kept low for stable labels.

        The API key is read from the ``OPENAI_API_KEY`` environment variable.
        """
        self.llm = OpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            temperature=temperature,
            model_name=model_name,
        )
        self.analysis_prompt = PromptTemplate(
            input_variables=["user_prompt"],
            template=_ANALYSIS_TEMPLATE,
        )

    def analyze_prompt(self, user_prompt: str) -> Dict[str, Any]:
        """Analyze *user_prompt* with the LLM, falling back to heuristics.

        Returns:
            A dict that always contains ``task_type``, ``complexity``,
            ``entities``, ``intent``, ``domain``, ``original_prompt``,
            ``word_count`` and ``char_count``. Fields the LLM omitted, left
            empty, or answered with a label outside ``TASK_TYPES`` /
            ``COMPLEXITY_LEVELS`` are taken from ``_fallback_analysis``.
        """
        fallback = self._fallback_analysis(user_prompt)
        if not user_prompt.strip():
            # Nothing to analyze: skip the remote round-trip.
            return fallback

        prompt_text = self.analysis_prompt.format(user_prompt=user_prompt)
        try:
            response = self.llm.invoke(prompt_text)
        except Exception:
            # Boundary around the remote call (network, auth, rate limits...):
            # degrade to the keyword heuristic instead of failing the caller.
            logger.warning("LLM analysis failed; falling back to keyword heuristic", exc_info=True)
            return fallback

        logger.debug("LLM raw response: %r", response)
        # Completion models return a str; chat-style models return a message
        # object carrying the text in `.content`.
        text = str(getattr(response, 'content', response))
        parsed = self._parse_analysis_result(text)

        # Reject labels outside the allowed sets so callers can rely on them.
        if parsed.get('task_type') not in self.TASK_TYPES:
            parsed.pop('task_type', None)
        if parsed.get('complexity') not in self.COMPLEXITY_LEVELS:
            parsed.pop('complexity', None)

        # Non-empty LLM fields override the heuristic ones; everything else
        # (including original_prompt/word_count/char_count) comes from fallback.
        return {**fallback, **{key: value for key, value in parsed.items() if value}}

    def _parse_analysis_result(self, analysis_result: str) -> Dict[str, Any]:
        """Parse the LLM's ``LABEL: value`` lines into a partial result dict.

        Labels are matched case-insensitively and tolerate markdown/numbering
        (see ``_LABEL_ALIASES``); unknown labels and lines without a colon are
        ignored. ``task_type`` and ``complexity`` are normalized to
        ``snake_case``, ``domain`` is lower-cased, and ``entities`` is split on
        commas. Keys absent from the response are absent from the result.
        """
        result: Dict[str, Any] = {}
        for line in analysis_result.splitlines():
            label, sep, raw_value = line.partition(':')
            if not sep:
                continue
            key = _LABEL_ALIASES.get(_normalize_label(label))
            if key is None:
                continue
            value = _unwrap(raw_value)
            if key in ('task_type', 'complexity'):
                result[key] = _normalize_label(value)
            elif key == 'entities':
                result[key] = [entity.strip() for entity in value.split(',') if entity.strip()]
            elif key == 'domain':
                result[key] = value.lower()
            else:  # intent
                result[key] = value
        return result

    def _fallback_analysis(self, user_prompt: str) -> Dict[str, Any]:
        """Heuristically analyze *user_prompt* without calling the LLM.

        - task_type: first task in ``_TASK_KEYWORDS`` with a whole-word keyword
          hit, else ``'text_generation'``.
        - complexity: by word count (>50 complex, >20 moderate, else simple).
        - entities: up to five distinct capitalized words, in order of appearance.
        """
        prompt_lower = user_prompt.lower()
        detected_task = next(
            (task for task, pattern in _TASK_PATTERNS.items() if pattern.search(prompt_lower)),
            'text_generation',
        )

        word_count = len(user_prompt.split())
        if word_count > 50:
            complexity = 'complex'
        elif word_count > 20:
            complexity = 'moderate'
        else:
            complexity = 'simple'

        # dict.fromkeys de-duplicates while keeping first-seen order.
        entities = list(dict.fromkeys(_ENTITY_RE.findall(user_prompt)))

        return {
            'task_type': detected_task,
            'complexity': complexity,
            'entities': entities[:5],
            'intent': 'User wants to ' + detected_task.replace('_', ' '),
            'domain': 'general',
            'original_prompt': user_prompt,
            'word_count': word_count,
            'char_count': len(user_prompt),
        }