# core/input_analysis_agent.py
"""
Input Analysis Agent
Analyzes user prompts to understand task type and intent.
"""
import re
from typing import Dict, List, Any
from langchain_community.llms import OpenAI
from langchain.prompts import PromptTemplate
import os
class InputAnalysisAgent:
    """Analyze a user prompt to infer task type, complexity, entities, intent, and domain.

    The primary path queries an OpenAI instruct model through LangChain and
    parses its structured ``KEY: value`` reply. If the LLM call raises for any
    reason (missing API key, network failure, quota), a purely keyword-based
    heuristic (:meth:`_fallback_analysis`) produces a best-effort result so
    callers never see an exception.
    """

    def __init__(self):
        # Low temperature: we want deterministic, machine-parseable output,
        # not creative variation.
        self.llm = OpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            temperature=0.1,
            model_name="gpt-3.5-turbo-instruct",
        )
        self.analysis_prompt = PromptTemplate(
            input_variables=["user_prompt"],
            template="""
Analyze the following user prompt and extract key information:
User Prompt: "{user_prompt}"
Please provide:
1. Task Type (choose **exactly one** from the list below):
[text_generation, question_answering, code_generation, summarization, translation, reasoning, creative_writing, data_analysis]
2. Complexity Level (simple, moderate, complex)
3. Key Entities (extract important nouns, concepts, or variables)
4. Intent (what the user wants to achieve)
5. Domain (e.g., general, technical, academic, creative, business)
Format your response as:
TASK_TYPE: [task_type]
COMPLEXITY: [complexity_level]
ENTITIES: [comma-separated list]
INTENT: [brief description]
DOMAIN: [domain]
""",
        )

    def analyze_prompt(self, user_prompt: str) -> Dict[str, Any]:
        """Analyze ``user_prompt`` and return a dict of analysis fields.

        Returns:
            Dict with keys ``task_type``, ``complexity``, ``entities``,
            ``intent``, ``domain``, ``original_prompt``, ``word_count``,
            ``char_count``. Never raises: any failure in the LLM path
            falls back to :meth:`_fallback_analysis`.
        """
        prompt_text = self.analysis_prompt.format(user_prompt=user_prompt)
        try:
            response = self.llm.invoke(prompt_text)
            print(f"\n[DEBUG] LLM raw response: {response}\n")
            # Instruct models return a plain string.
            parsed_result = self._parse_analysis_result(response.strip())
        except Exception as e:
            # Best-effort contract: an API/parse failure must never
            # propagate to the caller.
            print(f"[ERROR] Falling back due to exception: {e}")
            return self._fallback_analysis(user_prompt)
        parsed_result["original_prompt"] = user_prompt
        parsed_result["word_count"] = len(user_prompt.split())
        parsed_result["char_count"] = len(user_prompt)
        return parsed_result

    def _parse_analysis_result(self, analysis_result: str) -> Dict[str, Any]:
        """Parse the LLM's ``KEY: value`` lines into a dict.

        Robustness fixes over naive parsing:
        - Every expected key is pre-seeded with a safe default, so a reply
          that omits a line cannot cause a downstream ``KeyError`` (matches
          the guarantee the fallback path already provides).
        - Literal ``[brackets]`` are stripped, since the prompt template
          shows the answer format as ``TASK_TYPE: [task_type]`` and models
          sometimes echo it verbatim.
        - ``task_type`` and ``complexity`` are lowercased to match the
          canonical values requested in the template.
        """
        result: Dict[str, Any] = {
            'task_type': 'text_generation',
            'complexity': 'simple',
            'entities': [],
            'intent': '',
            'domain': 'general',
        }
        for line in analysis_result.strip().split('\n'):
            if ':' not in line:
                continue
            key, value = line.split(':', 1)
            key = key.strip().lower()
            # Unwrap values a model echoed as "[...]" from the template.
            value = value.strip().strip('[]').strip()
            if key == 'task_type':
                result['task_type'] = value.lower()
            elif key == 'complexity':
                result['complexity'] = value.lower()
            elif key == 'entities':
                result['entities'] = [e.strip() for e in value.split(',') if e.strip()]
            elif key == 'intent':
                result['intent'] = value
            elif key == 'domain':
                result['domain'] = value
        return result

    def _fallback_analysis(self, user_prompt: str) -> Dict[str, Any]:
        """Keyword-based heuristic analysis used when the LLM is unavailable.

        Detection rules:
        - Task type: first keyword match wins, in declaration order of
          ``task_keywords`` (so e.g. "write" maps to creative_writing
          before text_generation).
        - Complexity: thresholds on word count (>50 complex, >20 moderate).
        - Entities: capitalized words as a rough proxy, capped at five.
        """
        task_keywords = {
            'question_answering': ['what', 'how', 'why', 'when', 'where', 'who', 'explain', 'answer'],
            'code_generation': ['code', 'function', 'program', 'script', 'algorithm', 'implement'],
            'summarization': ['summarize', 'summary', 'brief', 'overview', 'key points'],
            'translation': ['translate', 'translation', 'convert to', 'in spanish', 'in french'],
            'creative_writing': ['story', 'poem', 'creative', 'write', 'compose', 'imagine'],
            'reasoning': ['solve', 'calculate', 'reason', 'logic', 'think', 'analyze'],
            'text_generation': ['generate', 'create', 'write', 'produce'],
        }
        prompt_lower = user_prompt.lower()
        detected_task = 'text_generation'
        for task, keywords in task_keywords.items():
            if any(keyword in prompt_lower for keyword in keywords):
                detected_task = task
                break

        words = user_prompt.split()
        if len(words) > 50:
            complexity = 'complex'
        elif len(words) > 20:
            complexity = 'moderate'
        else:
            complexity = 'simple'

        # Capitalized words as a rough stand-in for named entities.
        entities = re.findall(r'\b[A-Z][a-z]+\b', user_prompt)

        return {
            'task_type': detected_task,
            'complexity': complexity,
            'entities': entities[:5],
            'intent': 'User wants to ' + detected_task.replace('_', ' '),
            'domain': 'general',
            'original_prompt': user_prompt,
            'word_count': len(words),
            'char_count': len(user_prompt),
        }