File size: 5,062 Bytes
63bcd5a b09149c 63bcd5a 4552666 63bcd5a 4552666 63bcd5a b09149c 63bcd5a 4552666 63bcd5a 4552666 63bcd5a 5ec2fc9 63bcd5a 5ec2fc9 63bcd5a 4552666 63bcd5a 4552666 63bcd5a 809b701 63bcd5a 4552666 63bcd5a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 | import time
import logging
from typing import List
from google import genai
from src.recommendation_engine.config import (
GEMINI_API_KEY,
MODEL_CANDIDATES,
IDEA_TEMPERATURE,
FEATURE_TEMPERATURE,
CHAT_TEMPERATURE,
INTENT_TEMPERATURE,
IDEA_MAX_TOKENS,
FEATURE_MAX_TOKENS,
CHAT_MAX_TOKENS,
INTENT_MAX_TOKENS,
FULL_PROJECT_MAX_TOKENS,
TOP_P,
TOP_K,
MAX_RETRIES,
RETRY_DELAY_SECONDS,
ENABLE_LOGGING
)
from src.recommendation_engine.validator import validate_generated_list
# Module-level logger; records propagate to the root logger.
logger = logging.getLogger(__name__)

# Opt-in root logging configuration. logging.basicConfig is a no-op when the
# root logger already has handlers, so an application's own setup wins.
if ENABLE_LOGGING:
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(message)s",
        level=logging.INFO,
    )
class LLMProviderError(Exception):
    """Error describing a provider-side failure of the LLM backend.

    Attributes:
        message: Human-readable explanation; also the exception's ``str()``.
        status_code: Status code associated with the failure (503 by default).
    """

    def __init__(self, message: str, status_code: int = 503):
        self.message = message
        self.status_code = status_code
        super().__init__(message)
def classify_provider_error(error: Exception):
    """Map a provider exception to an ``LLMProviderError`` when recognised.

    The status is read from the exception's integer ``code`` attribute when
    present; otherwise the lower-cased message is searched for known markers.
    Bare status numbers in the message only count when they are not part of
    a longer number, so e.g. a token count of "314290" is not taken as a 429.

    Args:
        error: The exception raised by the SDK call.

    Returns:
        ``LLMProviderError`` with status 503 for a rejected/leaked API key,
        status 429 for quota or rate-limit exhaustion, or ``None`` otherwise.
    """
    import re

    text = str(error).lower()
    code = getattr(error, "code", None)
    if not isinstance(code, int):
        code = None

    def mentions_status(status: str) -> bool:
        # Digit boundaries (not word boundaries) so "403 Forbidden" and
        # "HTTP403" match but "14030ms" does not.
        return re.search(rf"(?<!\d){status}(?!\d)", text) is not None

    # An explicit 429 code is authoritative even if the message mentions keys.
    if code != 429 and (
        code == 403
        or "reported as leaked" in text
        or "permission_denied" in text
        or "api key" in text
        or mentions_status("403")
    ):
        return LLMProviderError(
            "Gemini API key was rejected. Create a new key, update .env, and restart the server.",
            status_code=503
        )
    if (
        code == 429
        or "resource_exhausted" in text
        or "quota" in text
        or "rate limit" in text
        or mentions_status("429")
    ):
        return LLMProviderError(
            "Gemini quota or rate limit is exhausted. Try again later or use another API key/project.",
            status_code=429
        )
    return None
# Shared Gemini client used by generate_text; constructed once at import time
# from the configured GEMINI_API_KEY.
client = genai.Client(
    api_key=GEMINI_API_KEY,
)
def extract_text(response) -> str:
    """Return the stripped text of a ``generate_content`` response.

    Uses the response's ``text`` attribute when it is non-empty. Otherwise
    joins (with single spaces) the ``text`` of every part of the first
    candidate, skipping parts whose text is missing or ``None``.

    Args:
        response: The SDK response object, or a falsy value.

    Returns:
        The extracted text, or "" when nothing usable is found.
    """
    if not response:
        return ""
    text = getattr(response, "text", None)
    if text:
        return text.strip()
    candidates = getattr(response, "candidates", None)
    if not candidates:
        return ""
    try:
        content = getattr(candidates[0], "content", None)
        parts = getattr(content, "parts", None) or []
        # Non-text parts can still expose ``text = None``; letting None into
        # join() raised TypeError and discarded the real text parts.
        pieces = [
            part_text
            for part_text in (getattr(part, "text", None) for part in parts)
            if isinstance(part_text, str)
        ]
    except (IndexError, TypeError):
        # Unexpected response shape: stay best-effort and report no text.
        return ""
    return " ".join(pieces).strip()
def get_temperature(task: str) -> float:
    """Return the sampling temperature configured for *task*.

    "idea", "feature" and "intent" have dedicated settings; every other task
    (including "chat") falls back to CHAT_TEMPERATURE.
    """
    task_temperatures = {
        "idea": IDEA_TEMPERATURE,
        "feature": FEATURE_TEMPERATURE,
        "intent": INTENT_TEMPERATURE,
    }
    try:
        return task_temperatures[task]
    except KeyError:
        return CHAT_TEMPERATURE
def get_max_tokens(task: str) -> int:
    """Return the output-token limit configured for *task*.

    "idea", "feature", "intent" and "full_project" have dedicated limits;
    any other task falls back to CHAT_MAX_TOKENS.
    """
    limits_by_task = dict(
        idea=IDEA_MAX_TOKENS,
        feature=FEATURE_MAX_TOKENS,
        intent=INTENT_MAX_TOKENS,
        full_project=FULL_PROJECT_MAX_TOKENS,
    )
    if task in limits_by_task:
        return limits_by_task[task]
    return CHAT_MAX_TOKENS
def safe_prompt(prompt: str, max_chars: int = 12000) -> str:
    """Return at most the last *max_chars* characters of *prompt*.

    Over-long prompts keep their tail; the beginning is dropped.
    NOTE(review): leading instructions are lost for long prompts — confirm
    that keeping the tail is intended.

    Args:
        prompt: The prompt text.
        max_chars: Maximum number of characters to keep; values <= 0 yield "".

    Returns:
        The (possibly truncated) prompt.
    """
    # prompt[-0:] is the whole string and a negative limit would slice from
    # the front, so non-positive limits must be handled explicitly.
    if max_chars <= 0:
        return ""
    if len(prompt) <= max_chars:
        return prompt
    return prompt[-max_chars:]
def is_bad_response(text: str) -> bool:
    """Heuristically decide whether *text* is an unusable model reply.

    A reply is "bad" when it is empty/None, shorter than 3 characters after
    stripping whitespace, or contains every one of the filler phrases
    "as an ai", "i can help you" and "let me know" (case-insensitive).
    NOTE(review): all three phrases must be present; ``any`` may have been
    intended — confirm.
    """
    if not text:
        return True
    stripped = text.strip()
    if len(stripped) < 3:
        return True
    lowered = stripped.lower()
    filler_phrases = ("as an ai", "i can help you", "let me know")
    return all(phrase in lowered for phrase in filler_phrases)
def generate_text(
    prompt: str,
    task: str = "chat",
    temperature=None
) -> str:
    """Generate text for *prompt*, falling back across ``MODEL_CANDIDATES``.

    Each candidate model is tried up to ``MAX_RETRIES`` times with a linearly
    growing delay between attempts (no delay after a model's last attempt).
    Exceptions are classified by ``classify_provider_error``:

    * rate-limit/quota errors (status 429) wait five times longer between
      attempts and, once a model's attempts are used up, fall through to the
      next model;
    * other classified errors (e.g. a rejected API key) are raised at once;
    * unclassified errors are logged and retried.

    Args:
        prompt: Prompt text; truncated by ``safe_prompt`` before sending.
        task: Task name selecting the default temperature and token limit.
        temperature: Sampling temperature; ``None`` uses the task default.

    Returns:
        The text of the first response obtained (a weak or empty reply is
        logged and returned as-is), or "" when every attempt failed with
        unclassified errors.

    Raises:
        LLMProviderError: immediately for non-429 provider errors, or with
            status 429 when every model failed and at least one attempt was
            rate limited.
    """
    prompt = safe_prompt(prompt)
    if temperature is None:
        temperature = get_temperature(task)
    max_tokens = get_max_tokens(task)
    # Identical for every attempt, so build it once.
    generation_config = {
        "temperature": temperature,
        "top_p": TOP_P,
        "top_k": TOP_K,
        "max_output_tokens": max_tokens,
    }
    # Most recent (LLMProviderError, cause) for a rate limit, surfaced at the
    # end if no model succeeds.
    rate_limit_failure = None

    for model_index, model_name in enumerate(MODEL_CANDIDATES):
        if model_index:
            logger.info("[LLM] switching model...")
        for attempt in range(MAX_RETRIES):
            logger.info(
                "[LLM] model=%s | task=%s | attempt=%s",
                model_name, task, attempt + 1,
            )
            try:
                response = client.models.generate_content(
                    model=model_name,
                    contents=prompt,
                    config=generation_config,
                )
            except Exception as e:
                logger.warning("[LLM ERROR] %s", e)
                provider_error = classify_provider_error(e)
                rate_limited = (
                    provider_error is not None
                    and provider_error.status_code == 429
                )
                if provider_error is not None and not rate_limited:
                    # Not retryable (e.g. rejected key).
                    raise provider_error from e
                if rate_limited:
                    rate_limit_failure = (provider_error, e)
                if attempt == MAX_RETRIES - 1:
                    # Out of attempts for this model: move on without sleeping.
                    continue
                if rate_limited:
                    sleep_time = (RETRY_DELAY_SECONDS * 5) * (attempt + 1)
                    logger.info(
                        "[LLM 429] Rate limited. Retrying in %ss...", sleep_time
                    )
                else:
                    sleep_time = RETRY_DELAY_SECONDS * (attempt + 1)
                time.sleep(sleep_time)
                continue

            text = extract_text(response)
            if is_bad_response(text):
                logger.warning("[LLM] Weak response, using anyway")
            return text

    logger.error("All LLM models failed")
    if rate_limit_failure is not None:
        provider_error, cause = rate_limit_failure
        raise provider_error from cause
    return ""
def generate_list(prompt: str, task="chat") -> List[str]:
    """Generate text for *prompt* and return it parsed into a list.

    The raw text comes from ``generate_text`` (whose ``LLMProviderError`` is
    not caught here) and is passed to ``validate_generated_list``.
    """
    raw_text = generate_text(prompt, task=task)
    items = validate_generated_list(raw_text)
    return items
|