# NOTE: removed non-Python residue scraped from a web file listing
# (page chrome, git blame hashes, and a gutter of line numbers) that
# made this module unparseable.
"""
Input Analysis Agent
Analyzes user prompts to understand task type and intent.
"""
import re
from typing import Dict, List, Any
from langchain_community.llms import OpenAI
from langchain.prompts import PromptTemplate
import os
class InputAnalysisAgent:
    """Classify a user prompt (task type, complexity, entities, intent, domain).

    Primary path sends the prompt to an OpenAI instruct model via LangChain and
    parses its ``KEY: value`` response; on any failure the agent falls back to
    a purely local keyword heuristic so callers always get a result dict.
    """

    def __init__(self):
        # NOTE(review): assumes OPENAI_API_KEY is set in the environment;
        # if it is missing the client will fail at call time, which is then
        # absorbed by the fallback in analyze_prompt().
        self.llm = OpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            temperature=0.1,  # near-deterministic output keeps parsing stable
            model_name="gpt-3.5-turbo-instruct",
        )
        self.analysis_prompt = PromptTemplate(
            input_variables=["user_prompt"],
            template="""
Analyze the following user prompt and extract key information:
User Prompt: "{user_prompt}"
Please provide:
1. Task Type (choose **exactly one** from the list below):
[text_generation, question_answering, code_generation, summarization, translation, reasoning, creative_writing, data_analysis]
2. Complexity Level (simple, moderate, complex)
3. Key Entities (extract important nouns, concepts, or variables)
4. Intent (what the user wants to achieve)
5. Domain (e.g., general, technical, academic, creative, business)
Format your response as:
TASK_TYPE: [task_type]
COMPLEXITY: [complexity_level]
ENTITIES: [comma-separated list]
INTENT: [brief description]
DOMAIN: [domain]
""",
        )

    def analyze_prompt(self, user_prompt: str) -> Dict[str, Any]:
        """Analyze *user_prompt* with the LLM, falling back to heuristics.

        Returns a dict containing whichever of ``task_type``, ``complexity``,
        ``entities``, ``intent`` and ``domain`` could be extracted, plus
        ``original_prompt``, ``word_count`` and ``char_count``.
        """
        prompt_text = self.analysis_prompt.format(user_prompt=user_prompt)
        try:
            response = self.llm.invoke(prompt_text)
            print(f"\n[DEBUG] LLM raw response: {response}\n")
            analysis_result = response.strip()  # Always a string with instruct models
            parsed_result = self._parse_analysis_result(analysis_result)
            parsed_result["original_prompt"] = user_prompt
            parsed_result["word_count"] = len(user_prompt.split())
            parsed_result["char_count"] = len(user_prompt)
            return parsed_result
        except Exception as e:
            # Best-effort by design: never propagate LLM/network errors to
            # callers; degrade to the local keyword heuristic instead.
            print(f"[ERROR] Falling back due to exception: {e}")
            return self._fallback_analysis(user_prompt)

    def _parse_analysis_result(self, analysis_result: str) -> Dict[str, Any]:
        """Parse the model's ``KEY: value`` lines into a result dict.

        Fix: the response template shows bracketed placeholders (e.g.
        ``TASK_TYPE: [task_type]``), so the model often echoes literal
        ``[`` / ``]`` around its answers; those brackets are now stripped
        from every value before it is stored.
        """
        result: Dict[str, Any] = {}
        for line in analysis_result.strip().split('\n'):
            if ':' not in line:
                continue  # ignore free-text lines the model may add
            key, value = line.split(':', 1)
            key = key.strip().lower()
            # Drop brackets copied from the prompt's answer template.
            value = value.strip().strip('[]').strip()
            if key == 'task_type':
                result['task_type'] = value
            elif key == 'complexity':
                result['complexity'] = value
            elif key == 'entities':
                result['entities'] = [entity.strip() for entity in value.split(',') if entity.strip()]
            elif key == 'intent':
                result['intent'] = value
            elif key == 'domain':
                result['domain'] = value
        return result

    def _fallback_analysis(self, user_prompt: str) -> Dict[str, Any]:
        """Local keyword-based analysis used when the LLM call fails.

        First matching task wins (dict order is significant: more specific
        tasks are listed before generic ``text_generation``).
        """
        task_keywords = {
            'question_answering': ['what', 'how', 'why', 'when', 'where', 'who', 'explain', 'answer'],
            'code_generation': ['code', 'function', 'program', 'script', 'algorithm', 'implement'],
            'summarization': ['summarize', 'summary', 'brief', 'overview', 'key points'],
            'translation': ['translate', 'translation', 'convert to', 'in spanish', 'in french'],
            'creative_writing': ['story', 'poem', 'creative', 'write', 'compose', 'imagine'],
            'reasoning': ['solve', 'calculate', 'reason', 'logic', 'think', 'analyze'],
            'text_generation': ['generate', 'create', 'write', 'produce'],
        }
        prompt_lower = user_prompt.lower()
        detected_task = 'text_generation'  # default when nothing matches
        for task, keywords in task_keywords.items():
            if any(keyword in prompt_lower for keyword in keywords):
                detected_task = task
                break
        # Crude complexity proxy: prompt length in words.
        word_count = len(user_prompt.split())
        complexity = 'simple'
        if word_count > 50:
            complexity = 'complex'
        elif word_count > 20:
            complexity = 'moderate'
        # Capitalized words as a rough stand-in for named entities.
        entities = re.findall(r'\b[A-Z][a-z]+\b', user_prompt)
        return {
            'task_type': detected_task,
            'complexity': complexity,
            'entities': entities[:5],
            'intent': 'User wants to ' + detected_task.replace('_', ' '),
            'domain': 'general',
            'original_prompt': user_prompt,
            'word_count': word_count,
            'char_count': len(user_prompt),
        }
|