Spaces:
Sleeping
Sleeping
| """ | |
| Input Analysis Agent | |
| Analyzes user prompts to understand task type and intent. | |
| """ | |
import logging
import os
import re
from typing import Any, Dict, List

from langchain.prompts import PromptTemplate
from langchain_community.llms import OpenAI
logger = logging.getLogger(__name__)


class InputAnalysisAgent:
    """Classify a user prompt by task type, complexity, entities, intent and domain.

    The primary path asks an OpenAI instruct model to fill in a fixed
    ``KEY: value`` report (``_ANALYSIS_TEMPLATE``).

    - If the model call raises, the keyword heuristic ``_fallback_analysis``
      answers instead.
    - If the model replies but omits a field, or picks a task type or
      complexity outside the allowed vocabulary, that field is filled from
      the heuristic.

    Every result also carries ``original_prompt``, ``word_count`` and
    ``char_count``.
    """

    # Task types the model is asked to choose from; keep in sync with the
    # list inside _ANALYSIS_TEMPLATE.
    TASK_TYPES = (
        "text_generation",
        "question_answering",
        "code_generation",
        "summarization",
        "translation",
        "reasoning",
        "creative_writing",
        "data_analysis",
    )
    # Complexity levels the model is asked to choose from.
    COMPLEXITY_LEVELS = ("simple", "moderate", "complex")

    # Fallback keyword table, checked in insertion order. The first task with
    # a matching keyword wins, so the generic text_generation bucket is last.
    _TASK_KEYWORDS = {
        "question_answering": ["what", "how", "why", "when", "where", "who", "explain", "answer"],
        "code_generation": ["code", "function", "program", "script", "algorithm", "implement"],
        "summarization": ["summarize", "summary", "brief", "overview", "key points"],
        "translation": ["translate", "translation", "convert to", "in spanish", "in french"],
        "creative_writing": ["story", "poem", "creative", "write", "compose", "imagine"],
        "reasoning": ["solve", "calculate", "reason", "logic", "think", "analyze"],
        "text_generation": ["generate", "create", "write", "produce"],
    }

    # Normalised header spellings accepted in the model's reply, mapped to
    # result keys. The "*_level" / "key_*" forms are how the numbered list
    # in the prompt phrases them, which the model sometimes echoes back.
    _KEY_ALIASES = {
        "task_type": "task_type",
        "complexity": "complexity",
        "complexity_level": "complexity",
        "entities": "entities",
        "key_entities": "entities",
        "intent": "intent",
        "domain": "domain",
    }

    _ANALYSIS_TEMPLATE = """
Analyze the following user prompt and extract key information:
User Prompt: "{user_prompt}"
Please provide:
1. Task Type (choose **exactly one** from the list below):
[text_generation, question_answering, code_generation, summarization, translation, reasoning, creative_writing, data_analysis]
2. Complexity Level (simple, moderate, complex)
3. Key Entities (extract important nouns, concepts, or variables)
4. Intent (what the user wants to achieve)
5. Domain (e.g., general, technical, academic, creative, business)
Format your response as:
TASK_TYPE: [task_type]
COMPLEXITY: [complexity_level]
ENTITIES: [comma-separated list]
INTENT: [brief description]
DOMAIN: [domain]
"""

    def __init__(self):
        """Create the instruct-model client and the analysis prompt template.

        The API key is read from the ``OPENAI_API_KEY`` environment variable
        (``None`` if unset; the client decides how to handle that).
        """
        self.llm = OpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            temperature=0.1,
            model_name="gpt-3.5-turbo-instruct",
        )
        self.analysis_prompt = PromptTemplate(
            input_variables=["user_prompt"],
            template=self._ANALYSIS_TEMPLATE,
        )

    def analyze_prompt(self, user_prompt: str) -> Dict[str, Any]:
        """Analyze *user_prompt* and return a dict describing it.

        Returned keys:

        - ``task_type``: one of ``TASK_TYPES``
        - ``complexity``: one of ``COMPLEXITY_LEVELS``
        - ``entities``: list of str
        - ``intent``, ``domain``: str
        - ``original_prompt``, ``word_count``, ``char_count``

        Any exception raised by the LLM call is logged and the heuristic
        result from ``_fallback_analysis`` is returned instead.
        """
        prompt_text = self.analysis_prompt.format(user_prompt=user_prompt)
        try:
            response = self.llm.invoke(prompt_text)
        except Exception:  # deliberate best-effort: network, auth, quota, ...
            logger.exception("LLM analysis failed; falling back to keyword heuristics")
            return self._fallback_analysis(user_prompt)

        # Completion models return a plain str; chat models return a message
        # object whose text is in ``.content``.
        text = str(getattr(response, "content", response))
        logger.debug("LLM raw response: %s", text)
        parsed = self._parse_analysis_result(text)

        # The model may omit fields or answer outside the allowed vocabulary;
        # fill those gaps from the keyword heuristic so callers always get
        # every key with a valid value.
        heuristic = self._fallback_analysis(user_prompt)
        if parsed.get("task_type") not in self.TASK_TYPES:
            parsed["task_type"] = heuristic["task_type"]
        if parsed.get("complexity") not in self.COMPLEXITY_LEVELS:
            parsed["complexity"] = heuristic["complexity"]
        # An explicitly empty ENTITIES line is kept as []; only a missing
        # line is filled in.
        parsed.setdefault("entities", heuristic["entities"])
        if not parsed.get("intent"):
            parsed["intent"] = "User wants to " + parsed["task_type"].replace("_", " ")
        if not parsed.get("domain"):
            parsed["domain"] = "general"

        parsed["original_prompt"] = user_prompt
        parsed["word_count"] = len(user_prompt.split())
        parsed["char_count"] = len(user_prompt)
        return parsed

    @staticmethod
    def _normalize_label(text: str) -> str:
        """Lower-case *text* and collapse every run of non-letters into ``_``.

        For example ``"**Code Generation**"`` becomes ``"code_generation"``.
        """
        return re.sub(r"[^a-z]+", "_", text.lower()).strip("_")

    @staticmethod
    def _clean_value(text: str) -> str:
        """Strip whitespace, markdown ``*`` markers and a wrapping ``[...]``.

        The template shows placeholders as ``[task_type]`` and similar, and
        the model often copies those brackets into its answer.
        """
        value = text.strip().strip("*").strip()
        if value.startswith("[") and value.endswith("]"):
            value = value[1:-1].strip()
        return value

    def _parse_analysis_result(self, analysis_result: str) -> Dict[str, Any]:
        """Parse the model's ``KEY: value`` report into a dict.

        - Header lines are matched case-insensitively after normalisation, so
          ``**TASK_TYPE**``, ``Task Type`` and ``Complexity Level`` are
          recognised.
        - Lines without a recognised header (preamble text, echoed template
          lines) are ignored.
        - ``task_type`` and ``complexity`` are normalised to snake_case.
        - ``entities`` is split on commas.
        - Only keys that appear in the reply are returned; a repeated key
          keeps its last value.
        """
        result: Dict[str, Any] = {}
        for line in analysis_result.splitlines():
            raw_key, sep, raw_value = line.partition(":")
            if not sep:
                continue
            key = self._KEY_ALIASES.get(self._normalize_label(raw_key))
            if key is None:
                continue
            value = self._clean_value(raw_value)
            if key in ("task_type", "complexity"):
                result[key] = self._normalize_label(value)
            elif key == "entities":
                cleaned = (part.strip().strip("\"'") for part in value.split(","))
                result[key] = [entity for entity in cleaned if entity]
            else:
                result[key] = value
        return result

    def _fallback_analysis(self, user_prompt: str) -> Dict[str, Any]:
        """Heuristic analysis used when the LLM is unavailable or incomplete.

        - Task type: the first ``_TASK_KEYWORDS`` entry that has a keyword
          starting a word in the lower-cased prompt; defaults to
          ``text_generation``. Requiring a word start keeps "show" from
          matching "how" and "history" from matching "story", while still
          allowing inflections such as "summarized".
        - Complexity: by whitespace word count (>50 complex, >20 moderate,
          otherwise simple).
        - Entities: up to five distinct capitalized words, in order of first
          appearance. Sentence-initial words are included.
        """
        prompt_lower = user_prompt.lower()
        detected_task = next(
            (
                task
                for task, keywords in self._TASK_KEYWORDS.items()
                if any(re.search(r"\b" + re.escape(kw), prompt_lower) for kw in keywords)
            ),
            "text_generation",
        )

        word_count = len(user_prompt.split())
        if word_count > 50:
            complexity = "complex"
        elif word_count > 20:
            complexity = "moderate"
        else:
            complexity = "simple"

        # dict.fromkeys de-duplicates while preserving first-seen order.
        entities: List[str] = list(dict.fromkeys(re.findall(r"\b[A-Z][a-z]+\b", user_prompt)))

        return {
            "task_type": detected_task,
            "complexity": complexity,
            "entities": entities[:5],
            "intent": "User wants to " + detected_task.replace("_", " "),
            "domain": "general",
            "original_prompt": user_prompt,
            "word_count": word_count,
            "char_count": len(user_prompt),
        }