File size: 4,771 Bytes
2e1d07a
 
 
 
 
 
 
4976deb
2e1d07a
 
 
 
 
 
 
a9a06c7
a7f9b26
2e1d07a
a9a06c7
2e1d07a
 
 
 
9ad84bb
2e1d07a
9ad84bb
2e1d07a
9f216f7
 
2e1d07a
 
 
 
9ad84bb
2e1d07a
 
 
 
 
 
 
 
9ad84bb
 
1770cc7
 
 
 
8a5de46
1770cc7
8a5de46
1770cc7
 
 
 
 
9ad84bb
1770cc7
 
 
da36a11
2e1d07a
 
 
 
 
 
 
 
9ad84bb
2e1d07a
 
 
 
 
 
 
 
 
 
9ad84bb
2e1d07a
9ad84bb
2e1d07a
 
 
 
 
 
 
 
 
 
9ad84bb
2e1d07a
9ad84bb
2e1d07a
 
 
 
9ad84bb
2e1d07a
 
 
 
 
9ad84bb
2e1d07a
9ad84bb
2e1d07a
 
 
9ad84bb
2e1d07a
 
 
 
 
 
1770cc7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""
Input Analysis Agent
Analyzes user prompts to understand task type and intent.
"""

import logging
import os
import re
from typing import Any, Dict, List

from langchain_community.llms import OpenAI
from langchain.prompts import PromptTemplate

class InputAnalysisAgent:
    """Classify a user prompt's task type, complexity, entities, intent and domain.

    Uses an OpenAI instruct model via LangChain to analyze the prompt; if the
    LLM call (or parsing) raises, falls back to a keyword/length heuristic so
    callers always get a complete result dict.
    """

    def __init__(self):
        # Low temperature: classification should be near-deterministic.
        self.llm = OpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            temperature=0.1,
            model_name="gpt-3.5-turbo-instruct"
        )

        self.analysis_prompt = PromptTemplate(
            input_variables=["user_prompt"],
            template="""
            Analyze the following user prompt and extract key information:

            User Prompt: "{user_prompt}"

            Please provide:
            1. Task Type (choose **exactly one** from the list below):
            [text_generation, question_answering, code_generation, summarization, translation, reasoning, creative_writing, data_analysis]
            2. Complexity Level (simple, moderate, complex)
            3. Key Entities (extract important nouns, concepts, or variables)
            4. Intent (what the user wants to achieve)
            5. Domain (e.g., general, technical, academic, creative, business)

            Format your response as:
            TASK_TYPE: [task_type]
            COMPLEXITY: [complexity_level]
            ENTITIES: [comma-separated list]
            INTENT: [brief description]
            DOMAIN: [domain]
            """
        )

    def analyze_prompt(self, user_prompt: str) -> Dict[str, Any]:
        """Analyze *user_prompt* with the LLM, falling back to heuristics on error.

        Returns a dict that always contains: ``task_type``, ``complexity``,
        ``entities``, ``intent``, ``domain``, ``original_prompt``,
        ``word_count``, ``char_count``.
        """
        prompt_text = self.analysis_prompt.format(user_prompt=user_prompt)
        try:
            # invoke() returns a plain string for instruct-style models.
            response = self.llm.invoke(prompt_text)
            logging.getLogger(__name__).debug("LLM raw response: %s", response)

            parsed_result = self._parse_analysis_result(response.strip())
            parsed_result["original_prompt"] = user_prompt
            parsed_result["word_count"] = len(user_prompt.split())
            parsed_result["char_count"] = len(user_prompt)
            return parsed_result

        except Exception as e:
            # Best-effort: an API or parse failure must never break the caller.
            logging.getLogger(__name__).warning(
                "Falling back to heuristic analysis due to exception: %s", e
            )
            return self._fallback_analysis(user_prompt)

    def _parse_analysis_result(self, analysis_result: str) -> Dict[str, Any]:
        """Parse ``KEY: value`` lines from the LLM response into a dict.

        Fields the model omitted are filled with safe defaults so callers can
        rely on every key being present (previously a missing line produced a
        partial dict and a downstream KeyError).
        """
        result: Dict[str, Any] = {}
        for line in analysis_result.strip().split('\n'):
            if ':' not in line:
                continue
            key, value = line.split(':', 1)
            key = key.strip().lower()
            value = value.strip()

            if key in ('task_type', 'complexity', 'intent', 'domain'):
                result[key] = value
            elif key == 'entities':
                # Comma-separated list; drop empty fragments.
                result['entities'] = [
                    entity.strip() for entity in value.split(',') if entity.strip()
                ]

        # Fill any fields the model omitted.
        result.setdefault('task_type', 'text_generation')
        result.setdefault('complexity', 'simple')
        result.setdefault('entities', [])
        result.setdefault('intent', '')
        result.setdefault('domain', 'general')
        return result

    def _fallback_analysis(self, user_prompt: str) -> Dict[str, Any]:
        """Heuristic analysis used when the LLM is unavailable.

        Detects task type by keyword match (dict order encodes priority),
        estimates complexity from word count, and treats capitalized words as
        candidate entities.
        """
        task_keywords = {
            'question_answering': ['what', 'how', 'why', 'when', 'where', 'who', 'explain', 'answer'],
            'code_generation': ['code', 'function', 'program', 'script', 'algorithm', 'implement'],
            'summarization': ['summarize', 'summary', 'brief', 'overview', 'key points'],
            'translation': ['translate', 'translation', 'convert to', 'in spanish', 'in french'],
            'creative_writing': ['story', 'poem', 'creative', 'write', 'compose', 'imagine'],
            'reasoning': ['solve', 'calculate', 'reason', 'logic', 'think', 'analyze'],
            'text_generation': ['generate', 'create', 'write', 'produce']
        }

        prompt_lower = user_prompt.lower()
        detected_task = 'text_generation'
        for task, keywords in task_keywords.items():
            if any(keyword in prompt_lower for keyword in keywords):
                detected_task = task
                break

        # Hoist the split: word list is used for both complexity and the count.
        words = user_prompt.split()
        if len(words) > 50:
            complexity = 'complex'
        elif len(words) > 20:
            complexity = 'moderate'
        else:
            complexity = 'simple'

        # Capitalized words as a rough proxy for named entities.
        entities = re.findall(r'\b[A-Z][a-z]+\b', user_prompt)

        return {
            'task_type': detected_task,
            'complexity': complexity,
            'entities': entities[:5],
            'intent': 'User wants to ' + detected_task.replace('_', ' '),
            'domain': 'general',
            'original_prompt': user_prompt,
            'word_count': len(words),
            'char_count': len(user_prompt)
        }