"""
Agent-style chat for Marketeer using LangChain tools.

- Uses ChatHuggingFace from llm_config.get_local_chat_model()
- Uses rewrite / tone tools from rewrite_tools.py
- Implements a tiny tool-calling loop with .bind_tools() (no AgentExecutor).
"""

from typing import Any, Dict, List, Tuple, Union

from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.tools import BaseTool

from core_logic.llm_config import get_local_chat_model
from core_logic.copy_pipeline import CopyRequest
from helpers.platform_styles import get_platform_style  # <-- dataclass style
from core_logic.rewrite_tools import get_rewrite_tools

Message = Union[HumanMessage, AIMessage]

# Platform used for the style lookup AND shown in the prompt when the request
# does not name one.
_DEFAULT_PLATFORM = "Instagram"


# --------------------------------------------------------------------
# Helpers for platform style and history
# --------------------------------------------------------------------


def _get_style_attr(style: Any, field: str, default: str = "") -> str:
    """
    Safe attribute getter for PlatformStyle dataclass (or dict fallback).

    This handles both:
    - dataclass PlatformStyle (preferred)
    - dict-like (in case of accidental mix)

    Returns ``default`` when ``style`` is None, when the field is missing,
    or when the stored value is None; otherwise the value converted to str.
    """
    if style is None:
        return default

    if isinstance(style, dict):
        # Dict path first, so a key that happens to share a name with a dict
        # method (e.g. "items") is read as data rather than as the method.
        value = style.get(field, default)
    else:
        # Dataclass / object path
        value = getattr(style, field, default)

    return default if value is None else str(value)


def _content_to_text(content: Any) -> str:
    """
    Flatten a LangChain message ``content`` value into plain text.

    ``content`` may be a plain string or a list of content blocks (strings
    and/or dicts carrying a "text" entry). Blocks without text are skipped.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: List[str] = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "".join(parts)
    return str(content)


def _build_system_prompt(req: CopyRequest) -> str:
    """
    Build a system instruction that explains:
    - you are a marketing copywriter
    - you know the campaign context
    - you may optionally use tools to rewrite/edit
    """
    # Resolve the platform once so the style lookup and the prompt text agree
    # (previously the prompt could say "Platform: None" while using the
    # Instagram style).
    platform_name = req.platform_name or _DEFAULT_PLATFORM

    # This comes from helpers.platform_styles and returns a PlatformStyle dataclass
    style = get_platform_style(platform_name)

    # Go through the safe getter so a missing / None field renders as an
    # empty string instead of the literal text "None".
    voice = _get_style_attr(style, "voice")
    emoji_guideline = _get_style_attr(style, "emoji_guideline")
    hashtag_guideline = _get_style_attr(style, "hashtag_guideline")
    formatting_guideline = _get_style_attr(style, "formatting_guideline")
    extra_notes = _get_style_attr(style, "extra_notes")

    # NOTE: the prompt body is kept at column 0 on purpose; indentation
    # inside the literal would end up in the text sent to the model.
    return f"""
You are Marketeer, an expert marketing copywriter.

You help users:
- write first-draft posts
- refine tone
- shorten or expand posts
- adapt copy across platforms

Campaign context:
- Brand: {req.brand}
- Product / offer: {req.product}
- Audience: {req.audience}
- Goal: {req.goal}
- Platform: {platform_name}
- Tone: {req.tone}
- CTA style: {req.cta_style}
- Extra context: {req.extra_context}

Platform style guidelines:
- Voice: {voice}
- Emoji usage: {emoji_guideline}
- Hashtags: {hashtag_guideline}
- Formatting: {formatting_guideline}
- Extra notes: {extra_notes}

You may have access to special tools that help you:
- adjust tone
- shorten or expand text
- remove or add emojis
- tweak style

When you respond:
- If the user clearly wants a simple answer, respond directly.
- If the user is asking to rewrite existing text (e.g. "shorten this",
  "make it more professional", "remove emojis"), feel free to call tools
  if they are available.
- Always return clean, user-ready copy (no JSON, no debug).
""".strip()


def _build_message_history(history_pairs: List[List[str]]) -> List[Message]:
    """
    Convert [[user, assistant], ...] into LangChain Human/AI messages.

    Entries that are empty or do not have exactly two items are skipped, as
    are empty user / assistant texts inside an otherwise valid pair.
    """
    messages: List[Message] = []
    for pair in history_pairs:
        if not pair or len(pair) != 2:
            continue

        user_text, assistant_text = pair
        if user_text:
            messages.append(HumanMessage(content=user_text))
        if assistant_text:
            messages.append(AIMessage(content=assistant_text))

    return messages


def _get_tool_map(tools: List[BaseTool]) -> Dict[str, BaseTool]:
    """
    Convenience map: tool_name -> tool object.
    """
    return {tool.name: tool for tool in tools}


# --------------------------------------------------------------------
# Main agent entry point
# --------------------------------------------------------------------


def agent_chat_turn(
    req: CopyRequest,
    user_message: str,
    history_pairs: List[List[str]] | None = None,
) -> Tuple[str, str, list]:
    """
    Run one chat turn with at most one round of tool calls.

    Args:
        req: Campaign context used to build the instruction prompt.
        user_message: The new user message for this turn.
        history_pairs: Earlier turns as [[user, assistant], ...]; None is
            treated as an empty history.

    Returns:
        A tuple ``(final_text, raw_text, audit)`` where ``final_text`` is the
        stripped reply to show the user, ``raw_text`` is the unstripped text
        of the model message that produced the reply, and ``audit`` is a list
        reserved for tool-call logs (currently always empty).
    """
    history_pairs = history_pairs or []

    # 1) Build base messages: "system" prompt as a HumanMessage + history + new user
    instructions = _build_system_prompt(req)

    # IMPORTANT: use HumanMessage here, not SystemMessage
    system_msg = HumanMessage(content=instructions)
    history_msgs = _build_message_history(history_pairs)
    new_user_msg = HumanMessage(content=user_message)

    messages: List[Union[Message, ToolMessage]] = (
        [system_msg] + history_msgs + [new_user_msg]
    )

    # 2) Prepare tools & model
    tools: List[BaseTool] = get_rewrite_tools()
    tool_map = _get_tool_map(tools)

    llm = get_local_chat_model()
    llm_with_tools = llm.bind_tools(tools)

    # 3) First model call (decide whether to use tools)
    ai_msg: AIMessage = llm_with_tools.invoke(messages)
    # content may be a list of blocks rather than a str; normalise before
    # calling str methods on it.
    raw_first = _content_to_text(ai_msg.content)

    # If the model does not request any tools, just return its answer
    tool_calls = getattr(ai_msg, "tool_calls", None)
    if not tool_calls:
        return raw_first.strip(), raw_first, []

    # 4) Execute any requested tools
    messages.append(ai_msg)

    tool_messages: List[ToolMessage] = []
    successful_outputs: List[str] = []

    for tool_call in tool_calls:
        tool_name = tool_call.get("name")
        args = tool_call.get("args") or {}  # guard against an explicit None
        call_id = tool_call.get("id") or ""

        tool = tool_map.get(tool_name)
        if tool is None:
            tool_output = f"Tool '{tool_name}' is not available."
        else:
            # LangChain tools usually implement .invoke()
            try:
                tool_output = tool.invoke(args)
            except Exception as e:
                # Deliberate best-effort: report the failure back to the
                # model instead of aborting the whole chat turn.
                tool_output = f"Tool '{tool_name}' failed with error: {e}"
            else:
                successful_outputs.append(str(tool_output))

        tool_messages.append(
            ToolMessage(
                content=str(tool_output),
                tool_call_id=call_id,
            )
        )

    messages.extend(tool_messages)

    # 5) Second model call: let the LLM see tool results and answer
    final_ai: AIMessage = llm_with_tools.invoke(messages)
    raw_second = _content_to_text(final_ai.content)
    final_text = raw_second.strip()

    if not final_text:
        # The model produced no text (e.g. it asked for yet another tool).
        # Rather than show an empty reply, fall back to what the tools
        # returned (presumably the rewritten copy), then to any text from
        # the first model message.
        final_text = "\n\n".join(successful_outputs).strip() or raw_first.strip()

    audit: list = []  # reserved for tool call logs if you want later
    return final_text, raw_second, audit