Spaces:
Sleeping
Sleeping
File size: 5,797 Bytes
030be06 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 | import os
from typing import Dict, Generator, List, Optional, Tuple, Union
from google import genai
from google.genai.types import GenerateContentConfig, GoogleSearch, Tool
from src.config.settings import GEMINI_API_KEY, MODEL_ID, MODEL_TEMPERATURE, SYSTEM_INSTRUCTION
def _get_api_key() -> Optional[str]:
    """Resolve the API key used to authenticate with Gemini.

    Prefers ``GEMINI_API_KEY`` from the settings module. When that value is
    falsy (e.g. ``None`` or an empty string), falls back to the
    ``GOOGLE_API_KEY`` environment variable, which may itself be ``None``.
    """
    configured_key = GEMINI_API_KEY
    if configured_key:
        return configured_key
    return os.getenv("GOOGLE_API_KEY")
# One chat message as a string->string dict; the code below reads/writes the
# "role" and "content" keys, e.g. {"role": "user", "content": "..."}.
ChatMessage = Dict[str, str]
# Chat history accepted by the prompt builder: a list whose items are either
# ChatMessage dicts or (role, message) string tuples.
ChatLog = List[Union[ChatMessage, Tuple[str, str]]]
def _build_prompt(question: str, chat_log: ChatLog) -> str:
if not chat_log:
return question
lines = ["Conversation so far:"]
for item in chat_log:
if isinstance(item, dict):
role = item.get("role", "assistant")
message = item.get("content", "")
else:
role, message = item
if role.lower().startswith("user") or role.lower().startswith("you"):
label = "User"
elif role.lower().startswith("assistant") or role.lower().startswith("ai"):
label = "Assistant"
else:
label = role
lines.append(f"{label}: {message}")
lines.append(f"User: {question}")
return "\n".join(lines)
def _create_client() -> genai.Client:
    """Build a ``genai.Client`` authenticated with the resolved API key.

    Raises:
        ValueError: If no key is available from either GEMINI_API_KEY
            (settings) or the GOOGLE_API_KEY environment variable.
    """
    key = _get_api_key()
    if key:
        return genai.Client(api_key=key)
    raise ValueError("Missing API key. Set GEMINI_API_KEY or GOOGLE_API_KEY.")
def _build_config(
    use_web_search: bool,
    temperature: float,
    system_instruction: str,
) -> GenerateContentConfig:
    """Assemble the per-request generation config.

    Args:
        use_web_search: When true, attach the Google Search tool so the model
            may ground its answer in web results.
        temperature: Sampling temperature forwarded to the model.
        system_instruction: System prompt forwarded to the model.

    Returns:
        A GenerateContentConfig whose tool list is empty when web search is off.
    """
    tools = []
    if use_web_search:
        tools.append(Tool(google_search=GoogleSearch()))
    return GenerateContentConfig(
        tools=tools,
        temperature=temperature,
        system_instruction=system_instruction,
    )
def google_search_query(
    question: str,
    use_web_search: bool,
    chat_log: Optional[ChatLog] = None,
    model_id: str = MODEL_ID,
    temperature: float = MODEL_TEMPERATURE,
    system_instruction: str = SYSTEM_INSTRUCTION,
) -> Tuple[str, str]:
    """Ask the model one question (non-streaming), optionally with Google Search.

    Args:
        question: The user's newest message.
        use_web_search: Whether to enable the Google Search tool.
        chat_log: Prior turns used as conversation context (may be None).
        model_id: Model name passed to ``generate_content``.
        temperature: Sampling temperature.
        system_instruction: System prompt.

    Returns:
        ``(answer, search_results)``. ``search_results`` is
        "Web search not used." when search is disabled, the rendered search
        entry-point content when the response carries it, or a notice when
        search was enabled but no grounding metadata came back. If creating
        the client or calling the model raises, returns
        ``("Error: <exception>", "")``.
    """
    chat_log = chat_log or []
    try:
        client = _create_client()
        prompt = _build_prompt(question, chat_log)
        config = _build_config(use_web_search, temperature, system_instruction)
        response = client.models.generate_content(
            model=model_id,
            contents=prompt,
            config=config,
        )
        ai_response = response.text or ""
    except Exception as exc:  # UI boundary: surface the failure as chat text.
        return f"Error: {exc}", ""

    if not use_web_search:
        return ai_response, "Web search not used."

    # The model may answer without searching; then grounding_metadata (or its
    # search_entry_point) is None. Walking the chain unguarded would raise and
    # throw away a perfectly good answer, so check each level.
    candidates = response.candidates or []
    metadata = candidates[0].grounding_metadata if candidates else None
    entry_point = metadata.search_entry_point if metadata else None
    rendered = entry_point.rendered_content if entry_point else None
    return ai_response, rendered or "No web search results were returned."
def google_search_query_stream(
    question: str,
    use_web_search: bool,
    chat_log: Optional[ChatLog] = None,
    model_id: str = MODEL_ID,
    temperature: float = MODEL_TEMPERATURE,
    system_instruction: str = SYSTEM_INSTRUCTION,
) -> Generator[Tuple[str, str], None, None]:
    """Stream the model's answer, yielding the accumulated text as it grows.

    Args:
        question: The user's newest message.
        use_web_search: Whether to enable the Google Search tool.
        chat_log: Prior turns used as conversation context (may be None).
        model_id: Model name passed to ``generate_content_stream``.
        temperature: Sampling temperature.
        system_instruction: System prompt.

    Yields:
        ``(text_so_far, "")`` after every chunk that carries text, then one
        final ``(full_text, search_results)``. ``search_results`` is
        "Web search not used." when search is disabled, the rendered search
        entry-point content captured from the stream when present, or a
        notice otherwise. If anything raises while streaming, yields
        ``("Error: <exception>", "")`` and stops.
    """
    chat_log = chat_log or []
    try:
        client = _create_client()
        prompt = _build_prompt(question, chat_log)
        config = _build_config(use_web_search, temperature, system_instruction)
        response_stream = client.models.generate_content_stream(
            model=model_id,
            contents=prompt,
            config=config,
        )
        collected = []
        rendered = None
        for chunk in response_stream:
            if chunk.text:
                collected.append(chunk.text)
                yield "".join(collected), ""
            # Grounding metadata is delivered on the streamed chunks, so record
            # it here. The old follow-up generate_content call doubled cost and
            # latency and returned metadata from a *different* generation than
            # the answer the user just read.
            candidates = chunk.candidates or []
            metadata = candidates[0].grounding_metadata if candidates else None
            entry_point = metadata.search_entry_point if metadata else None
            if entry_point and entry_point.rendered_content:
                rendered = entry_point.rendered_content
        final_text = "".join(collected)
    except Exception as exc:  # UI boundary: surface the failure as chat text.
        yield f"Error: {exc}", ""
        return

    if use_web_search:
        search_results = rendered or "No web search results were returned."
    else:
        search_results = "Web search not used."
    yield final_text, search_results
def update_chatbot(
    question: str,
    use_web_search: bool,
    chat_log: Optional[List[ChatMessage]],
    model_id: str,
    temperature: float,
    stream: bool,
) -> Generator[List[ChatMessage], None, None]:
    """Add ``question`` and the model's reply to ``chat_log``, yielding the log.

    Because the body contains ``yield``, this is a generator in every mode, so
    every path *yields* the log. (A ``return value`` inside a generator is
    never seen by a caller that iterates it, which previously left the
    non-streaming and empty-question paths producing nothing.)

    Args:
        question: The user's newest message; if empty, the log is yielded
            unchanged.
        use_web_search: Whether to enable Google Search grounding. When the
            query returns non-empty search results, they are appended as an
            extra assistant message.
        chat_log: Message-format history; mutated in place (None starts a
            new list).
        model_id: Model name forwarded to the query helpers.
        temperature: Sampling temperature forwarded to the query helpers.
        stream: When true, yield the log after every streamed partial answer.

    Yields:
        The updated ``chat_log`` list.
    """
    if chat_log is None:
        chat_log = []
    if not question:
        yield chat_log
        return

    # Snapshot the history *before* appending the question: the query helpers
    # add "User: {question}" themselves, so passing the question in the history
    # as well would put it in the prompt twice.
    history = list(chat_log)
    chat_log.append({"role": "user", "content": question})

    search_results = ""
    if stream:
        answer_index = None
        for partial, search_results in google_search_query_stream(
            question,
            use_web_search,
            history,
            model_id,
            temperature,
        ):
            message = {"role": "assistant", "content": partial}
            if answer_index is None:
                chat_log.append(message)
                answer_index = len(chat_log) - 1
            else:
                chat_log[answer_index] = message
            yield chat_log
    else:
        ai_response, search_results = google_search_query(
            question,
            use_web_search,
            history,
            model_id,
            temperature,
        )
        chat_log.append({"role": "assistant", "content": ai_response})

    # Error paths return empty search results; skip the extra message then.
    if use_web_search and search_results:
        chat_log.append(
            {
                "role": "assistant",
                "content": f"Web Search Results:\n{search_results}",
            }
        )
        yield chat_log
    elif not stream:
        # The streaming loop already yielded the final answer; the
        # non-streaming path has not yielded anything yet.
        yield chat_log
|