Spaces:
Runtime error
Runtime error
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import re | |
| from datetime import datetime, timezone | |
| from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple | |
| from google.adk.agents import Agent | |
| from google.adk.agents.run_config import RunConfig, StreamingMode | |
| from google.adk.runners import InMemoryRunner | |
| from google.adk.tools import FunctionTool | |
| from google.genai import types | |
| from app.services.agentic_prompt import ( | |
| get_vector_search_prompt, | |
| get_web_search_prompt, | |
| ) | |
| from app.services.chathistory import ChatSession | |
| from app.services.environmental_condition import EnvironmentalData | |
| from app.services.tools import ( | |
| analyze_skin_image, | |
| convert_document_to_markdown, | |
| get_image_search, | |
| get_vector_search, | |
| get_web_search, | |
| ) | |
logger = logging.getLogger(__name__)

# Accept either variable name for the key; GOOGLE_API_KEY wins when both are set.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
DEFAULT_MODEL_NAME = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")

# Names assigned (via __name__) to the per-request tool closures that
# GoogleAgentService builds; the prompt builders receive them too.
PERSONALIZED_TOOL_NAME = "get_personalized_context"
ENVIRONMENT_TOOL_NAME = "get_environmental_context"
DOCUMENT_CONVERSION_TOOL_NAME = "convert_uploaded_document"
IMAGE_ANALYSIS_TOOL_NAME = "analyze_skin_image"

# Mirror a GEMINI_API_KEY-only configuration into GOOGLE_API_KEY, presumably so
# the Google client libraries that read GOOGLE_API_KEY pick it up.
if not os.getenv("GOOGLE_API_KEY") and GOOGLE_API_KEY:
    os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

if os.name == "nt":
    # Best effort: prefer the proactor event loop on Windows. A failure here
    # must not break module import, but it should not vanish silently either.
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
    except Exception as exc:
        logger.debug("Could not set the Windows proactor event loop policy: %s", exc)
class GoogleAgentService:
    """Chat orchestrator that streams responses from a Google ADK agent.

    Construction eagerly loads the user's preferences, language, profile,
    city and (when enabled) environmental data through ``ChatSession`` and
    ``EnvironmentalData``. :meth:`process_message_async` then builds an ADK
    ``Agent`` with the tools relevant to this request, streams its output as
    event dicts and persists the final exchange with ``ChatSession.save_chat``.
    """

    def __init__(
        self,
        token: str,
        session_id: Optional[str] = None,
        document: Optional[Dict[str, Any]] = None,
        image: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Load the per-user context needed to answer one chat request.

        Args:
            token: Auth token forwarded to ``ChatSession``.
            session_id: Existing chat session id; ``None`` or blank means a new
                session is created when the first message is processed.
            document: Optional uploaded-document record. Keys read by this
                class: ``path``, ``name``, ``type``, ``extension``.
            image: Optional uploaded-image record. Same keys plus ``prompt``.
        """
        self.token = token
        self.session_id = session_id
        self.chat_session = ChatSession(token, session_id)
        self.user_preferences = self._load_user_preferences()
        self.language = self.chat_session.get_language() or "english"
        self.user_profile = self._load_user_profile()
        self.user_city = self.chat_session.get_city()
        # Reads user_preferences and user_city, so it must run after both.
        self.environment_data = self._load_environmental_data()
        self.document = document
        self.image = image

    async def process_message_async(
        self, query: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Run the agent on ``query`` and stream progress events.

        Yields dicts keyed by ``type``:

        * ``"chunk"``: a piece of reply text. Text is split after whitespace,
          and a trailing partial word is held back until more text arrives or
          the stream ends.
        * ``"tool_call"`` / ``"tool_result"``: the agent invoking a tool and
          that tool's response.
        * ``"completed"``: the parsed final reply, merged images/references,
          all tool calls, and the value returned by ``save_chat`` (``saved``).
        * ``"error"``: a terminal failure; nothing is yielded afterwards.
        """
        if not GOOGLE_API_KEY:
            error = "Google API key is not configured."
            logger.error(error)
            yield {"type": "error", "content": error}
            return

        try:
            session_id = self._ensure_valid_session(query)
            agent_mode = "web" if self.user_preferences.get("websearch") else "vector"
            user_data = self._prepare_user_data()
            agent = self._build_agent(agent_mode, user_data)
            runner = InMemoryRunner(agent=agent)
            await runner.session_service.create_session(
                app_name=runner.app_name,
                user_id=self.chat_session.identity,
                session_id=session_id,
            )
            user_message = types.Content(
                role="user",
                parts=[types.Part(text=query)],
            )
            run_config = RunConfig(streaming_mode=StreamingMode.SSE)

            tool_calls: List[Dict[str, Any]] = []
            tool_call_map: Dict[str, Dict[str, Any]] = {}
            collected_images: List[str] = []
            collected_references: List[str] = []
            streamed_text = ""  # all partial text, across every model turn
            turn_text = ""  # partial text of the model turn in progress
            final_text = ""  # text of the most recent non-partial event
            pending_token_buffer = ""

            def emit_word_chunks(delta: str, *, final: bool = False) -> List[str]:
                """Buffer ``delta``; return every whitespace-terminated token.

                With ``final=True`` the unterminated remainder is flushed too.
                """
                nonlocal pending_token_buffer
                pending_token_buffer += delta
                chunks: List[str] = []
                consumed = 0
                # Each match is "non-space run + one whitespace char"; matches
                # are contiguous, so one pass replaces search-and-slice loops.
                for match in re.finditer(r"\S*\s", pending_token_buffer):
                    chunks.append(match.group())
                    consumed = match.end()
                pending_token_buffer = pending_token_buffer[consumed:]
                if final and pending_token_buffer:
                    chunks.append(pending_token_buffer)
                    pending_token_buffer = ""
                return chunks

            async for event in runner.run_async(
                user_id=self.chat_session.identity,
                session_id=session_id,
                new_message=user_message,
                run_config=run_config,
            ):
                if event.error_message:
                    logger.error("Agent error: %s", event.error_message)
                    yield {"type": "error", "content": event.error_message}
                    return

                for function_call in event.get_function_calls():
                    arguments = function_call.args or {}
                    call_entry = {
                        "id": function_call.id,
                        "tool_name": function_call.name,
                        "arguments": arguments,
                    }
                    tool_call_map[function_call.id] = call_entry
                    tool_calls.append(call_entry)
                    yield {
                        "type": "tool_call",
                        "id": function_call.id,
                        "tool_name": function_call.name,
                        "arguments": arguments,
                    }

                for function_response in event.get_function_responses():
                    response_payload = function_response.response or {}
                    call_entry = tool_call_map.get(function_response.id)
                    if call_entry is not None:
                        call_entry["result"] = response_payload
                    if isinstance(response_payload, dict):
                        if function_response.name == "get_image_search":
                            collected_images.extend(response_payload.get("images", []))
                        if response_payload.get("references"):
                            collected_references.extend(response_payload["references"])
                    yield {
                        "type": "tool_result",
                        "id": function_response.id,
                        "tool_name": function_response.name,
                        "result": response_payload,
                    }

                text_segment = self._extract_text(event)
                if event.partial:
                    if text_segment:
                        turn_text += text_segment
                        streamed_text += text_segment
                        for token in emit_word_chunks(text_segment):
                            yield {"type": "chunk", "content": token}
                    continue

                # A non-partial event closes the current model turn. The prefix
                # check assumes it carries that turn's aggregated text, so only
                # the not-yet-streamed tail is emitted. The buffer is per turn:
                # comparing against text from earlier turns (e.g. a preamble
                # before a tool call) would fail and re-emit the whole reply.
                already_streamed, turn_text = turn_text, ""
                if not text_segment:
                    continue
                final_text = text_segment
                if not already_streamed:
                    delta = text_segment
                elif text_segment.startswith(already_streamed):
                    delta = text_segment[len(already_streamed):]
                else:
                    # The final text diverged from what was streamed. The
                    # "completed" event carries the authoritative text, so do
                    # not duplicate output here.
                    delta = ""
                for token in emit_word_chunks(delta, final=True):
                    yield {"type": "chunk", "content": token}

            for leftover in emit_word_chunks("", final=True):
                yield {"type": "chunk", "content": leftover}

            parsed_response = self._parse_agent_response(final_text or streamed_text)
            response_text, keywords, response_images, response_refs = parsed_response
            merged_images = self._dedupe_list(collected_images + response_images)
            merged_references = self._dedupe_list(collected_references + response_refs)

            context_chunks: List[str] = []
            if self.document and self.document.get("path"):
                context_chunks.append(f"document:{self.document.get('path')}")
            if self.image and self.image.get("path"):
                context_chunks.append(f"image:{self.image.get('path')}")
            context_payload = " ".join(context_chunks)

            chat_payload = {
                "query": query,
                "response": response_text,
                "references": merged_references,
                "keywords": keywords,
                "images": merged_images,
                "context": context_payload,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "session_id": session_id,
                "tool_calls": tool_calls,
            }
            saved = self.chat_session.save_chat(chat_payload)
            yield {
                "type": "completed",
                "saved": saved,
                "session_id": session_id,
                "response": response_text,
                "keywords": keywords,
                "references": merged_references,
                "images": merged_images,
                "tool_calls": tool_calls,
            }
        except Exception as exc:
            logger.error("Agent streaming failure: %s", exc, exc_info=True)
            yield {"type": "error", "content": f"Generation failed: {exc}"}

    def _build_agent(self, mode: str, user_data: Dict[str, Any]) -> "Agent":
        """Assemble the ADK agent for this request.

        Args:
            mode: ``"web"`` selects the web-search prompt and tool; any other
                value selects vector search.
            user_data: Output of :meth:`_prepare_user_data`. It feeds the
                prompt and decides which optional tools are attached.

        Returns:
            An ``Agent`` named ``DermAI`` using ``DEFAULT_MODEL_NAME``.
        """
        is_web = mode == "web"
        prompt = get_web_search_prompt(user_data) if is_web else get_vector_search_prompt(user_data)
        search_tool = get_web_search if is_web else get_vector_search

        tools: List[FunctionTool] = []
        if self.image and self.image.get("path"):
            tools.append(FunctionTool(self._create_image_tool()))
        tools.append(FunctionTool(search_tool))
        tools.append(FunctionTool(get_image_search))
        if self.document and self.document.get("path"):
            tools.append(FunctionTool(self._create_document_tool()))
        if user_data.get("has_personalized_data"):
            personalized_tool = self._create_personalized_data_tool(
                user_data.get("personalized_data", "")
            )
            tools.append(FunctionTool(personalized_tool))
        if user_data.get("has_environmental_data"):
            environmental_tool = self._create_environmental_data_tool(
                user_data.get("environmental_payload") or {}
            )
            tools.append(FunctionTool(environmental_tool))

        return Agent(
            name="DermAI",
            model=DEFAULT_MODEL_NAME,
            instruction=prompt,
            tools=tools,
        )

    def _create_document_tool(self) -> Callable[..., Dict[str, Any]]:
        """Build the document-conversion tool bound to this session's upload.

        The returned callable only converts ``self.document["path"]``. Any
        other ``file_path`` supplied by the model is rejected, which keeps the
        agent from being steered into reading arbitrary server files.
        """
        document_record = self.document or {}
        allowed_path = (document_record.get("path") or "").strip()
        allowed_normalized = allowed_path.replace("\\", "/").strip()

        def run_document_conversion(
            file_path: Optional[str] = None,
            file_extension: Optional[str] = None,
        ) -> Dict[str, Any]:
            target_path = (file_path or allowed_path or "").replace("\\", "/").strip()
            if not target_path:
                return {
                    "status": "error",
                    "error_message": "file_path is required to convert a document.",
                }
            # Compare unconditionally: with no recorded upload, every path is
            # rejected instead of accepted.
            if target_path != allowed_normalized:
                return {
                    "status": "error",
                    "error_message": "The provided file_path does not match the uploaded document for this session.",
                }
            result = convert_document_to_markdown(
                file_path=target_path,
                file_extension=file_extension or document_record.get("extension"),
            )
            if result.get("status") == "success":
                text_content = result.get("text_content") or ""
                result["preview"] = text_content[:1000]
                result["character_count"] = len(text_content)
                if allowed_path and result.get("source_path"):
                    # Report the path as uploaded, not the converter's resolved form.
                    result["source_path"] = allowed_path
            return result

        # ADK derives the tool's name and description from these attributes.
        run_document_conversion.__name__ = DOCUMENT_CONVERSION_TOOL_NAME
        run_document_conversion.__doc__ = (
            "Convert the user's uploaded dermatology document into Markdown text. "
            "Provide the `file_path` exactly as supplied in the conversation context."
        )
        return run_document_conversion

    def _create_image_tool(self) -> Callable[..., Dict[str, Any]]:
        """Build the skin-image analysis tool bound to this session's upload.

        Only ``self.image["path"]`` may be analysed. Other paths supplied by
        the model are rejected.
        """
        image_record = self.image or {}
        allowed_path = (image_record.get("path") or "").strip()
        allowed_normalized = allowed_path.replace("\\", "/").strip()

        def run_image_analysis(
            file_path: Optional[str] = None,
            language: Optional[str] = None,
        ) -> Dict[str, Any]:
            target_path = (file_path or allowed_path or "").replace("\\", "/").strip()
            if not target_path:
                return {
                    "status": "error",
                    "error_message": "file_path is required to analyse the image.",
                }
            # Compare unconditionally: with no recorded upload, every path is
            # rejected instead of accepted.
            if target_path != allowed_normalized:
                return {
                    "status": "error",
                    "error_message": "The provided file_path does not match the uploaded image for this session.",
                }
            result = analyze_skin_image(
                file_path=target_path,
                language=language or self.language,
            )
            if result.get("status") == "success" and allowed_path:
                result["image_path"] = allowed_path
            return result

        # ADK derives the tool's name and description from these attributes.
        run_image_analysis.__name__ = IMAGE_ANALYSIS_TOOL_NAME
        run_image_analysis.__doc__ = (
            "Analyse the user's uploaded skin image. Provide the `file_path` exactly as supplied "
            "in the conversation context to run the classifier."
        )
        return run_image_analysis

    def _create_personalized_data_tool(self, data: str) -> Callable[[], Dict[str, Any]]:
        """Return a zero-argument tool that serves the (stripped) ``data`` text."""
        sanitized = (data or "").strip()

        def personalized_tool() -> Dict[str, Any]:
            return {
                "status": "success",
                "generated_at": datetime.now(timezone.utc).isoformat(),
                "personalized_data": sanitized,
            }

        personalized_tool.__name__ = PERSONALIZED_TOOL_NAME
        personalized_tool.__doc__ = (
            "Return questionnaire-derived personalization details for the current user."
        )
        return personalized_tool

    def _create_environmental_data_tool(
        self, data: Dict[str, Any]
    ) -> Callable[[], Dict[str, Any]]:
        """Return a zero-argument tool serving a snapshot of ``data``.

        The snapshot is a shallow copy taken now, together with the user's city.
        """
        snapshot = dict(data) if isinstance(data, dict) else {}
        city = self.user_city

        def environmental_tool() -> Dict[str, Any]:
            return {
                "status": "success",
                "city": city,
                "retrieved_at": datetime.now(timezone.utc).isoformat(),
                "environmental_data": snapshot,
            }

        environmental_tool.__name__ = ENVIRONMENT_TOOL_NAME
        environmental_tool.__doc__ = (
            "Return the cached environmental conditions for the user's location."
        )
        return environmental_tool

    def _load_user_preferences(self) -> Dict[str, Any]:
        """Fetch the user's feature toggles, falling back to safe defaults.

        Defaults apply when the lookup raises or returns ``None``. A ``None``
        result would otherwise crash every later ``.get`` on the preferences.
        """
        defaults = {
            "websearch": False,
            "keywords": True,
            "references": True,
            "personalized_recommendations": False,
            "environmental_recommendations": False,
        }
        try:
            preferences = self.chat_session.get_user_preferences()
        except Exception as exc:
            logger.warning("Failed to load user preferences: %s", exc)
            return defaults
        if preferences is None:
            return defaults
        return preferences

    def _load_user_profile(self) -> Dict[str, Any]:
        """Return ``{"name", "age"}`` with ``"Patient"``/``"Unknown"`` fallbacks.

        Fallbacks apply for missing, ``None`` or empty values, so a null column
        never reaches the prompt as the string "None".
        """
        try:
            profile = self.chat_session.get_name_and_age() or {}
            name = profile.get("name")
            age = profile.get("age")
            return {
                "name": name if name not in (None, "") else "Patient",
                "age": age if age not in (None, "") else "Unknown",
            }
        except Exception as exc:
            logger.warning("Failed to load profile: %s", exc)
            return {"name": "Patient", "age": "Unknown"}

    def _load_environmental_data(self) -> Optional[Dict[str, Any]]:
        """Fetch environmental data for the user's city when that preference is on.

        Returns ``None`` when disabled, when no city is known, when the lookup
        returns nothing, or when it raises (the error is logged).
        """
        try:
            if (
                self.user_preferences.get("environmental_recommendations")
                and self.user_city
            ):
                data = EnvironmentalData(self.user_city).get_environmental_data()
                if data:
                    return data
        except Exception as exc:
            logger.warning("Failed to load environmental data: %s", exc)
        return None

    def _load_personalized_data(self) -> str:
        """Return personalized recommendation text, or ``""`` if disabled or failing."""
        try:
            if self.user_preferences.get("personalized_recommendations"):
                data = self.chat_session.get_personalized_recommendation()
                return data or ""
        except Exception as exc:
            logger.warning("Failed to load personalized data: %s", exc)
        return ""

    def _prepare_user_data(self) -> Dict[str, Any]:
        """Collect the values passed to the prompt builders and to ``_build_agent``.

        The ``has_*`` flags and ``*_tool_name`` entries tell the prompt which
        optional tools exist. ``environmental_payload`` (a dict) feeds the
        environment tool, while ``environmental_data`` is its JSON string
        form for the prompt.
        """
        personalized_data = self._load_personalized_data()
        environmental_payload = (
            dict(self.environment_data)
            if isinstance(self.environment_data, dict)
            else {}
        )
        has_personalized_data = bool(personalized_data)
        has_environmental_data = bool(environmental_payload)

        document_info = None
        if self.document and self.document.get("path"):
            document_info = {
                "path": self.document.get("path"),
                "name": self.document.get("name") or "Uploaded document",
                "type": self.document.get("type"),
                "extension": self.document.get("extension"),
            }

        image_info = None
        if self.image and self.image.get("path"):
            image_info = {
                "path": self.image.get("path"),
                "name": self.image.get("name") or "Uploaded image",
                "type": self.image.get("type"),
                "extension": self.image.get("extension"),
                "prompt": self.image.get("prompt"),
            }

        return {
            "name": self.user_profile.get("name"),
            "age": self.user_profile.get("age"),
            "language": self.language,
            "personalized_recommendations": self.user_preferences.get(
                "personalized_recommendations"
            ),
            "environmental_recommendations": self.user_preferences.get(
                "environmental_recommendations"
            ),
            "personalized_data": personalized_data,
            "environmental_data": json.dumps(environmental_payload)
            if has_environmental_data
            else "",
            "has_personalized_data": has_personalized_data,
            "has_environmental_data": has_environmental_data,
            "personalized_tool_name": PERSONALIZED_TOOL_NAME
            if has_personalized_data
            else None,
            "environmental_tool_name": ENVIRONMENT_TOOL_NAME
            if has_environmental_data
            else None,
            "environmental_payload": environmental_payload,
            "include_keywords": self.user_preferences.get("keywords", True),
            "include_references": self.user_preferences.get("references", True),
            "include_images": True,
            "recent_history": self._get_recent_history(),
            "document_info": document_info,
            "document_tool_name": DOCUMENT_CONVERSION_TOOL_NAME
            if document_info
            else None,
            "image_info": image_info,
            "image_tool_name": IMAGE_ANALYSIS_TOOL_NAME if image_info else None,
        }

    def _get_recent_history(self, limit: int = 10) -> str:
        """Format the last ``limit`` exchanges as ``User:`` / ``Dr DermAI:`` lines.

        Returns ``""`` when there is no session, no history, ``limit <= 0``,
        or loading fails (the failure is logged).
        """
        # Guard explicitly: history[-0:] would return the *entire* history.
        if limit <= 0:
            return ""
        try:
            if not self.session_id:
                return ""
            self.chat_session.load_chat_history()
            history_items = self.chat_session.get_chat_history() or []
            if not history_items:
                return ""
            formatted: List[str] = []
            for entry in history_items[-limit:]:
                user_q = entry.get("query") or ""
                bot_a = entry.get("response") or ""
                if user_q:
                    formatted.append(f"User: {user_q}")
                if bot_a:
                    formatted.append(f"Dr DermAI: {bot_a}")
            return "\n".join(formatted[-limit * 2:])
        except Exception as exc:
            logger.warning("Failed to load recent history: %s", exc)
            return ""

    def _ensure_valid_session(self, title: Optional[str] = None) -> str:
        """Return a usable session id, creating a new session if needed.

        A new session (titled ``title``) is created when the current id is
        missing or blank, fails validation, or validation raises.
        """
        needs_new_session = not (self.session_id and self.session_id.strip())
        if not needs_new_session:
            try:
                needs_new_session = not self.chat_session.validate_session(
                    self.session_id, title=title
                )
            except Exception as exc:
                logger.warning(
                    "Session validation failed; creating a new session: %s", exc
                )
                needs_new_session = True
        if needs_new_session:
            self.chat_session.create_new_session(title=title)
            self.session_id = self.chat_session.session_id
        return self.session_id

    @staticmethod
    def _extract_text(event) -> str:
        """Concatenate the user-visible text parts of an ADK event.

        Parts flagged as model thoughts (``part.thought``) are skipped so
        internal reasoning is never streamed to the user.
        """
        if not event.content or not event.content.parts:
            return ""
        return "".join(
            part.text
            for part in event.content.parts
            if part.text and not getattr(part, "thought", False)
        )

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (optionally tagged ``json``)."""
        stripped = text.strip()
        if stripped.startswith("```") and stripped.endswith("```"):
            body = stripped.strip("`")
            if body.lower().startswith("json"):
                body = body[4:]
            stripped = body
        return stripped.strip()

    @staticmethod
    def _coerce_str_list(value: Any) -> List[str]:
        """Normalise a JSON field into a list of stripped, non-empty strings.

        A list keeps its non-blank items; any other truthy value becomes a
        one-element list; falsy values give ``[]``.
        """
        if isinstance(value, list):
            return [str(item).strip() for item in value if str(item).strip()]
        if value:
            return [str(value).strip()]
        return []

    def _parse_agent_response(self, text: str) -> Tuple[str, List[str], List[str], List[str]]:
        """Parse the agent's JSON reply into ``(response, keywords, images, references)``.

        Non-JSON or non-object replies are returned as raw text with empty lists.
        """
        cleaned = self._strip_code_fence(text)
        if not cleaned:
            return "", [], [], []
        try:
            payload = json.loads(cleaned)
        except json.JSONDecodeError:
            logger.warning("Unable to parse agent JSON response; returning raw text.")
            return cleaned, [], [], []
        if not isinstance(payload, dict):
            return cleaned, [], [], []
        response_text = payload.get("response") or cleaned
        keywords = self._coerce_str_list(payload.get("keywords", []))
        images = self._coerce_str_list(payload.get("images", []))
        references = self._coerce_str_list(payload.get("references", []))
        return response_text, keywords, images, references

    @staticmethod
    def _dedupe_list(items: List[str]) -> List[str]:
        """Drop falsy items and duplicates, keeping first-occurrence order."""
        return list(dict.fromkeys(item for item in items if item))