# app/services/google_agent_service.py
import asyncio
import json
import logging
import os
import re
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple
from google.adk.agents import Agent
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.runners import InMemoryRunner
from google.adk.tools import FunctionTool
from google.genai import types
from app.services.agentic_prompt import (
get_vector_search_prompt,
get_web_search_prompt,
)
from app.services.chathistory import ChatSession
from app.services.environmental_condition import EnvironmentalData
from app.services.tools import (
analyze_skin_image,
convert_document_to_markdown,
get_image_search,
get_vector_search,
get_web_search,
)
logger = logging.getLogger(__name__)
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
DEFAULT_MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
PERSONALIZED_TOOL_NAME = "get_personalized_context"
ENVIRONMENT_TOOL_NAME = "get_environmental_context"
DOCUMENT_CONVERSION_TOOL_NAME = "convert_uploaded_document"
IMAGE_ANALYSIS_TOOL_NAME = "analyze_skin_image"
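# Allow GEMINI_API_KEY as a fallback: export it as GOOGLE_API_KEY so the
# google-genai / ADK clients can discover it from the environment.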
if not os.getenv("GOOGLE_API_KEY") and GOOGLE_API_KEY:
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY
if os.name == "nt":
try:
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
except Exception:
pass
class GoogleAgentService:
"""Chat orchestrator that streams responses from a Google ADK agent."""
def __init__(
self,
token: str,
session_id: Optional[str] = None,
document: Optional[Dict[str, Any]] = None,
image: Optional[Dict[str, Any]] = None,
) -> None:
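        """Capture the auth token, session id, stored preferences and profile, and any uploaded document or image."""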
self.token = token
self.session_id = session_id
self.chat_session = ChatSession(token, session_id)
self.user_preferences = self._load_user_preferences()
self.language = self.chat_session.get_language() or "english"
self.user_profile = self._load_user_profile()
self.user_city = self.chat_session.get_city()
self.environment_data = self._load_environmental_data()
self.document = document
self.image = image
async def process_message_async(
self, query: str
) -> AsyncGenerator[Dict[str, Any], None]:
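        """Run one user turn through the ADK agent and yield streaming events.

        Yields dicts keyed by ``type``: ``tool_call`` and ``tool_result``
        as tools execute, ``chunk`` for word-sized text, ``error`` on
        failure, and a final ``completed`` payload with the parsed
        response, keywords, references, images, and tool-call log.
        """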
if not GOOGLE_API_KEY:
error = "Google API key is not configured."
logger.error(error)
yield {"type": "error", "content": error}
return
try:
session_id = self._ensure_valid_session(query)
agent_mode = "web" if self.user_preferences.get("websearch") else "vector"
user_data = self._prepare_user_data()
agent = self._build_agent(agent_mode, user_data)
runner = InMemoryRunner(agent=agent)
await runner.session_service.create_session(
app_name=runner.app_name,
user_id=self.chat_session.identity,
session_id=session_id,
)
user_message = types.Content(
role="user",
parts=[types.Part(text=query)],
)
run_config = RunConfig(streaming_mode=StreamingMode.SSE)
tool_calls: List[Dict[str, Any]] = []
tool_call_map: Dict[str, Dict[str, Any]] = {}
collected_images: List[str] = []
collected_references: List[str] = []
streamed_text = ""
final_text = ""
pending_token_buffer = ""
def emit_word_chunks(delta: str, *, final: bool = False) -> List[str]:
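                """Buffer streamed deltas and emit whitespace-delimited chunks.

                Text is split at whitespace boundaries so clients never
                receive a partial word; ``final=True`` flushes whatever
                remains in the buffer.
                """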
nonlocal pending_token_buffer
pending_token_buffer += delta
chunks: List[str] = []
while pending_token_buffer:
match = re.search(r'\s', pending_token_buffer)
if not match:
break
idx = match.end()
token = pending_token_buffer[:idx]
pending_token_buffer = pending_token_buffer[idx:]
if token:
chunks.append(token)
if final and pending_token_buffer:
chunks.append(pending_token_buffer)
pending_token_buffer = ""
return chunks
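            # Drive the agent turn: surface tool calls and results as they
            # happen, and forward model text as word-sized chunks.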
async for event in runner.run_async(
user_id=self.chat_session.identity,
session_id=session_id,
new_message=user_message,
run_config=run_config,
):
if event.error_message:
logger.error("Agent error: %s", event.error_message)
yield {"type": "error", "content": event.error_message}
return
for function_call in event.get_function_calls():
call_entry = {
"id": function_call.id,
"tool_name": function_call.name,
"arguments": function_call.args or {},
}
tool_call_map[function_call.id] = call_entry
tool_calls.append(call_entry)
yield {
"type": "tool_call",
"id": function_call.id,
"tool_name": function_call.name,
"arguments": function_call.args or {},
}
for function_response in event.get_function_responses():
response_payload = function_response.response or {}
call_entry = tool_call_map.get(function_response.id)
if call_entry is not None:
call_entry["result"] = response_payload
if isinstance(response_payload, dict):
if function_response.name == "get_image_search":
collected_images.extend(response_payload.get("images", []))
if response_payload.get("references"):
collected_references.extend(response_payload["references"])
yield {
"type": "tool_result",
"id": function_response.id,
"tool_name": function_response.name,
"result": response_payload,
}
text_segment = self._extract_text(event)
if not text_segment:
continue
if event.partial:
streamed_text += text_segment
for token in emit_word_chunks(text_segment):
yield {"type": "chunk", "content": token}
else:
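                    # Non-partial events can repeat the full text streamed so
                    # far; emit only the unseen suffix to avoid duplicates.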
final_text = text_segment
if streamed_text and text_segment.startswith(streamed_text):
delta = text_segment[len(streamed_text) :]
else:
delta = text_segment
if text_segment:
streamed_text = text_segment
for token in emit_word_chunks(delta, final=True):
yield {"type": "chunk", "content": token}
for leftover in emit_word_chunks("", final=True):
if leftover:
yield {"type": "chunk", "content": leftover}
parsed_response = self._parse_agent_response(final_text or streamed_text)
response_text, keywords, response_images, response_refs = parsed_response
merged_images = self._dedupe_list(collected_images + response_images)
merged_references = self._dedupe_list(collected_references + response_refs)
context_chunks: List[str] = []
if self.document and self.document.get("path"):
context_chunks.append(f"document:{self.document.get('path')}")
if self.image and self.image.get("path"):
context_chunks.append(f"image:{self.image.get('path')}")
context_payload = " ".join(context_chunks)
chat_payload = {
"query": query,
"response": response_text,
"references": merged_references,
"keywords": keywords,
"images": merged_images,
"context": context_payload,
"timestamp": datetime.now(timezone.utc).isoformat(),
"session_id": session_id,
"tool_calls": tool_calls,
}
saved = self.chat_session.save_chat(chat_payload)
yield {
"type": "completed",
"saved": saved,
"session_id": session_id,
"response": response_text,
"keywords": keywords,
"references": merged_references,
"images": merged_images,
"tool_calls": tool_calls,
}
except Exception as exc:
logger.error("Agent streaming failure: %s", exc, exc_info=True)
yield {"type": "error", "content": f"Generation failed: {exc}"}
def _build_agent(self, mode: str, user_data: Dict[str, Any]) -> Agent:
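        """Assemble the ADK agent: pick the web or vector search prompt and register only the tools this session can use."""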
prompt = (
get_web_search_prompt(user_data)
if mode == "web"
else get_vector_search_prompt(user_data)
)
search_tool = get_web_search if mode == "web" else get_vector_search
tools: List[FunctionTool] = []
if self.image and self.image.get("path"):
tools.append(FunctionTool(self._create_image_tool()))
tools.extend(
[
FunctionTool(search_tool),
FunctionTool(get_image_search),
]
)
if self.document and self.document.get("path"):
tools.append(FunctionTool(self._create_document_tool()))
if user_data.get("has_personalized_data"):
personalized_tool = self._create_personalized_data_tool(
user_data.get("personalized_data", "")
)
tools.append(FunctionTool(personalized_tool))
if user_data.get("has_environmental_data"):
environmental_tool = self._create_environmental_data_tool(
user_data.get("environmental_payload") or {}
)
tools.append(FunctionTool(environmental_tool))
agent = Agent(
name="DermAI",
model=DEFAULT_MODEL_NAME,
instruction=prompt,
tools=tools,
)
return agent
def _create_document_tool(self) -> Callable[..., Dict[str, Any]]:
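        """Build a conversion tool pinned to the uploaded document's path so the model cannot read arbitrary files."""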
document_record = self.document or {}
allowed_path = (document_record.get("path") or "").strip()
def run_document_conversion(
file_path: Optional[str] = None,
file_extension: Optional[str] = None,
) -> Dict[str, Any]:
target_path = (file_path or allowed_path or "").replace("\\", "/").strip()
if not target_path:
return {
"status": "error",
"error_message": "file_path is required to convert a document.",
}
allowed_normalized = allowed_path.replace("\\", "/").strip()
if allowed_normalized and target_path != allowed_normalized:
return {
"status": "error",
"error_message": "The provided file_path does not match the uploaded document for this session.",
}
result = convert_document_to_markdown(
file_path=target_path,
file_extension=file_extension or document_record.get("extension"),
)
if result.get("status") == "success":
text_content = result.get("text_content") or ""
result["preview"] = text_content[:1000]
result["character_count"] = len(text_content)
if allowed_path and result.get("source_path"):
            # Report the session's uploaded path rather than any resolved absolute path.
result["source_path"] = allowed_path
return result
run_document_conversion.__name__ = DOCUMENT_CONVERSION_TOOL_NAME
run_document_conversion.__doc__ = (
"Convert the user's uploaded dermatology document into Markdown text. "
"Provide the `file_path` exactly as supplied in the conversation context."
)
return run_document_conversion
def _create_image_tool(self) -> Callable[..., Dict[str, Any]]:
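        """Build an image-analysis tool pinned to the uploaded image's path, mirroring the document tool's guard."""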
image_record = self.image or {}
allowed_path = (image_record.get("path") or "").strip()
def run_image_analysis(
file_path: Optional[str] = None,
language: Optional[str] = None,
) -> Dict[str, Any]:
target_path = (file_path or allowed_path or "").replace("\\", "/").strip()
if not target_path:
return {
"status": "error",
"error_message": "file_path is required to analyse the image.",
}
allowed_normalized = allowed_path.replace("\\", "/").strip()
if allowed_normalized and target_path != allowed_normalized:
return {
"status": "error",
"error_message": "The provided file_path does not match the uploaded image for this session.",
}
result = analyze_skin_image(
file_path=target_path,
language=language or self.language,
)
if result.get("status") == "success" and allowed_path:
result["image_path"] = allowed_path
return result
run_image_analysis.__name__ = IMAGE_ANALYSIS_TOOL_NAME
run_image_analysis.__doc__ = (
"Analyse the user's uploaded skin image. Provide the `file_path` exactly as supplied "
"in the conversation context to run the classifier."
)
return run_image_analysis
def _create_personalized_data_tool(self, data: str) -> Callable[[], Dict[str, Any]]:
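        """Wrap questionnaire-derived personalization data in a zero-argument tool the agent can call on demand."""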
sanitized = (data or "").strip()
def personalized_tool() -> Dict[str, Any]:
return {
"status": "success",
"generated_at": datetime.now(timezone.utc).isoformat(),
"personalized_data": sanitized,
}
personalized_tool.__name__ = PERSONALIZED_TOOL_NAME
personalized_tool.__doc__ = (
"Return questionnaire-derived personalization details for the current user."
)
return personalized_tool
def _create_environmental_data_tool(
self, data: Dict[str, Any]
) -> Callable[[], Dict[str, Any]]:
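        """Wrap the cached environmental snapshot in a zero-argument tool; a copy is taken so later mutation is safe."""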
snapshot = dict(data) if isinstance(data, dict) else {}
city = self.user_city
def environmental_tool() -> Dict[str, Any]:
return {
"status": "success",
"city": city,
"retrieved_at": datetime.now(timezone.utc).isoformat(),
"environmental_data": snapshot,
}
environmental_tool.__name__ = ENVIRONMENT_TOOL_NAME
environmental_tool.__doc__ = (
"Return the cached environmental conditions for the user's location."
)
return environmental_tool
def _load_user_preferences(self) -> Dict[str, Any]:
try:
return self.chat_session.get_user_preferences()
except Exception as exc:
logger.warning("Failed to load user preferences: %s", exc)
return {
"websearch": False,
"keywords": True,
"references": True,
"personalized_recommendations": False,
"environmental_recommendations": False,
}
def _load_user_profile(self) -> Dict[str, Any]:
try:
profile = self.chat_session.get_name_and_age() or {}
return {
"name": profile.get("name", "Patient"),
"age": profile.get("age", "Unknown"),
}
except Exception as exc:
logger.warning("Failed to load profile: %s", exc)
return {"name": "Patient", "age": "Unknown"}
def _load_environmental_data(self) -> Optional[Dict[str, Any]]:
try:
if (
self.user_preferences.get("environmental_recommendations")
and self.user_city
):
data = EnvironmentalData(self.user_city).get_environmental_data()
if data:
return data
except Exception as exc:
logger.warning("Failed to load environmental data: %s", exc)
return None
def _load_personalized_data(self) -> str:
try:
if self.user_preferences.get("personalized_recommendations"):
data = self.chat_session.get_personalized_recommendation()
return data or ""
except Exception as exc:
logger.warning("Failed to load personalized data: %s", exc)
return ""
def _prepare_user_data(self) -> Dict[str, Any]:
personalized_data = self._load_personalized_data()
environmental_payload = (
dict(self.environment_data)
if isinstance(self.environment_data, dict)
else {}
)
has_personalized_data = bool(personalized_data)
has_environmental_data = bool(environmental_payload)
document_info = None
if self.document and self.document.get("path"):
document_info = {
"path": self.document.get("path"),
"name": self.document.get("name") or "Uploaded document",
"type": self.document.get("type"),
"extension": self.document.get("extension"),
}
image_info = None
if self.image and self.image.get("path"):
image_info = {
"path": self.image.get("path"),
"name": self.image.get("name") or "Uploaded image",
"type": self.image.get("type"),
"extension": self.image.get("extension"),
"prompt": self.image.get("prompt"),
}
return {
"name": self.user_profile.get("name"),
"age": self.user_profile.get("age"),
"language": self.language,
"personalized_recommendations": self.user_preferences.get(
"personalized_recommendations"
),
"environmental_recommendations": self.user_preferences.get(
"environmental_recommendations"
),
"personalized_data": personalized_data,
"environmental_data": json.dumps(environmental_payload)
if has_environmental_data
else "",
"has_personalized_data": has_personalized_data,
"has_environmental_data": has_environmental_data,
"personalized_tool_name": PERSONALIZED_TOOL_NAME
if has_personalized_data
else None,
"environmental_tool_name": ENVIRONMENT_TOOL_NAME
if has_environmental_data
else None,
"environmental_payload": environmental_payload,
"include_keywords": self.user_preferences.get("keywords", True),
"include_references": self.user_preferences.get("references", True),
"include_images": True,
"recent_history": self._get_recent_history(),
"document_info": document_info,
"document_tool_name": DOCUMENT_CONVERSION_TOOL_NAME
if document_info
else None,
"image_info": image_info,
"image_tool_name": IMAGE_ANALYSIS_TOOL_NAME if image_info else None,
}
def _get_recent_history(self, limit: int = 10) -> str:
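        """Return up to the last ``limit`` exchanges formatted as alternating User / Dr DermAI lines."""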
try:
if not self.session_id:
return ""
self.chat_session.load_chat_history()
history_items = self.chat_session.get_chat_history() or []
if not history_items:
return ""
recent = history_items[-limit:]
formatted = []
for entry in recent:
user_q = entry.get("query") or ""
bot_a = entry.get("response") or ""
if user_q:
formatted.append(f"User: {user_q}")
if bot_a:
formatted.append(f"Dr DermAI: {bot_a}")
return "\n".join(formatted[-limit * 2:])
except Exception as exc:
logger.warning("Failed to load recent history: %s", exc)
return ""
def _ensure_valid_session(self, title: Optional[str] = None) -> str:
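        """Return a usable session id, creating a new session when the current one is missing or fails validation."""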
if not self.session_id or not self.session_id.strip():
self.chat_session.create_new_session(title=title)
self.session_id = self.chat_session.session_id
else:
try:
if not self.chat_session.validate_session(self.session_id, title=title):
self.chat_session.create_new_session(title=title)
self.session_id = self.chat_session.session_id
except Exception:
self.chat_session.create_new_session(title=title)
self.session_id = self.chat_session.session_id
return self.session_id
@staticmethod
def _extract_text(event) -> str:
if not event.content or not event.content.parts:
return ""
parts: List[str] = []
for part in event.content.parts:
if part.text:
parts.append(part.text)
return "".join(parts)
@staticmethod
def _strip_code_fence(text: str) -> str:
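        """Remove a surrounding Markdown code fence, including an optional ``json`` tag, from the model's reply."""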
stripped = text.strip()
if stripped.startswith("```") and stripped.endswith("```"):
body = stripped.strip("`")
if body.lower().startswith("json"):
body = body[4:]
stripped = body
return stripped.strip()
    def _parse_agent_response(
        self, text: str
    ) -> Tuple[str, List[str], List[str], List[str]]:
        cleaned = self._strip_code_fence(text)
        if not cleaned:
            return "", [], [], []
        try:
            payload = json.loads(cleaned)
        except json.JSONDecodeError:
            logger.warning("Unable to parse agent JSON response; returning raw text.")
            return cleaned, [], [], []
        if not isinstance(payload, dict):
            return cleaned, [], [], []
        response_text = payload.get("response") or cleaned
        keywords = self._coerce_str_list(payload.get("keywords", []))
        images = self._coerce_str_list(payload.get("images", []))
        references = self._coerce_str_list(payload.get("references", []))
        return response_text, keywords, images, references
    @staticmethod
    def _coerce_str_list(value: Any) -> List[str]:
        """Coerce a JSON field that may be a list, a scalar, or missing into a list of non-empty strings."""
        if isinstance(value, list):
            return [str(item).strip() for item in value if str(item).strip()]
        if value:
            return [str(value).strip()]
        return []
@staticmethod
def _dedupe_list(items: List[str]) -> List[str]:
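        """De-duplicate while preserving first-seen order; empty values are dropped."""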
seen = set()
deduped: List[str] = []
for item in items:
if not item:
continue
if item in seen:
continue
seen.add(item)
deduped.append(item)
return deduped