File size: 3,905 Bytes
aeaf873
 
95ee675
 
aeaf873
 
 
 
 
 
 
 
 
 
 
95ee675
aeaf873
 
 
 
 
 
95ee675
 
aeaf873
 
95ee675
aeaf873
 
95ee675
aeaf873
 
95ee675
 
 
 
 
 
 
 
 
 
 
 
aeaf873
 
 
95ee675
aeaf873
 
 
95ee675
aeaf873
 
2a07ca7
95ee675
 
2a07ca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95ee675
 
 
aeaf873
95ee675
 
 
 
aeaf873
 
 
 
 
 
 
 
95ee675
aeaf873
 
 
 
 
 
 
 
95ee675
aeaf873
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import json
import yaml
from pathlib import Path
from typing import List, Dict, Any
import google.generativeai as genai
from loguru import logger
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from dotenv import load_dotenv

# Load Env
load_dotenv()

API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    logger.critical("GEMINI_API_KEY is missing")
    raise ValueError("GEMINI_API_KEY not found in environment")

genai.configure(api_key=API_KEY)

class LLMEngine:
    """
    LLM Engine with Externalized Configuration.
    Loads prompts from config/prompts.yaml to ensure code/data separation.
    """

    def __init__(self, model_name: str = "gemini-2.5-flash"):
        self.model_name = model_name
        self.model = genai.GenerativeModel(model_name)
        self.prompts = self._load_prompts()
        logger.info(f"LLM Engine initialized with {model_name}")

    def _load_prompts(self) -> Dict[str, Any]:
        """Loads prompt templates from the config directory."""
        # Resolves path relative to this file: src/core/../../config/prompts.yaml
        base_path = Path(__file__).resolve().parent.parent.parent
        config_path = base_path / "config" / "prompts.yaml"
        
        if not config_path.exists():
            raise FileNotFoundError(f"Configuration file not found at: {config_path}")
            
        with open(config_path, "r") as f:
            return yaml.safe_load(f)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(Exception),
        reraise=True
    )
    async def generate_test_cases(self, requirements_text: str) -> List[Dict[str, Any]]:
        logger.debug("Constructing prompt from configuration...")
        
        try:
            # Load config sections
            config = self.prompts["test_generation"]
            
            # 1. System Role & Context
            system_role = config.get("system_role", "")
            
            # 2. Main Instruction (Injecting the dynamic text)
            instruction = config.get("instruction", "").format(
                requirements_text=requirements_text[:30000]
            )
            
            # 3. Few-Shot Examples (NEW: Capturing your expanded section)
            examples = config.get("few_shot_examples", "")
            
            # 4. Output Schema
            fmt = config.get("output_format", "")
            
            # Construct the Full Context Window
            # Order matters: Role -> Task -> Examples -> Strict JSON Format
            full_prompt = (
                f"{system_role}\n\n"
                f"{instruction}\n\n"
                f"### REFERENCE EXAMPLES:\n{examples}\n\n"
                f"### REQUIRED OUTPUT FORMAT:\n{fmt}"
            )

            # Async call
            response = await self.model.generate_content_async(full_prompt)
            return self._parse_json_response(response.text)
            
        except KeyError as e:
            logger.error(f"Missing configuration key in prompts.yaml: {e}")
            raise e
        except Exception as e:
            logger.error(f"LLM Generation Failed: {e}")
            raise e

    def _parse_json_response(self, text: str) -> List[Dict[str, Any]]:
        clean_text = text.strip()
        if clean_text.startswith("```json"):
            clean_text = clean_text[7:-3]
        elif clean_text.startswith("```"):
            clean_text = clean_text[3:-3]
            
        try:
            data = json.loads(clean_text)
            if not isinstance(data, list):
                raise ValueError("Output is not a list of test cases")
            return data
        except json.JSONDecodeError as e:
            logger.error("Failed to decode JSON response")
            raise ValueError(f"Malformed JSON from LLM: {e}")