# Marketeer / core_logic / chat_agent.py
# Uploaded by KPrashanth (commit 85ba912, "Upload 47 files")
"""
Agent-style chat for Marketeer using LangChain tools.
- Uses ChatHuggingFace from llm_config.get_local_chat_model()
- Uses rewrite / tone tools from rewrite_tools.py
- Implements a tiny tool-calling loop with .bind_tools() (no AgentExecutor).
"""
from typing import Any, Dict, List, Tuple, Union
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.tools import BaseTool
from core_logic.llm_config import get_local_chat_model
from core_logic.copy_pipeline import CopyRequest
from helpers.platform_styles import get_platform_style # <-- dataclass style
from core_logic.rewrite_tools import get_rewrite_tools
Message = Union[HumanMessage, AIMessage]
# --------------------------------------------------------------------
# Helpers for platform style and history
# --------------------------------------------------------------------
def _get_style_attr(style: Any, field: str, default: str = "") -> str:
"""
Safe attribute getter for PlatformStyle dataclass (or dict fallback).
This handles both:
- dataclass PlatformStyle (preferred)
- dict-like (in case of accidental mix)
"""
if style is None:
return default
# Dataclass / object path
if hasattr(style, field):
value = getattr(style, field)
return default if value is None else str(value)
# Dict path (just in case)
if isinstance(style, dict):
value = style.get(field, default)
return default if value is None else str(value)
return default
def _build_system_prompt(req: CopyRequest) -> str:
    """
    Build the instruction prompt for one agent turn.

    The prompt tells the model it is a marketing copywriter, injects the
    campaign context from *req*, and summarizes the platform style
    guidelines so the generated copy stays on-brand.

    Args:
        req: Campaign request; ``req.platform_name`` selects the style
             (defaults to "Instagram" when falsy).

    Returns:
        The full instruction text, stripped of surrounding whitespace.
    """
    # This comes from helpers.platform_styles and returns a PlatformStyle dataclass
    style = get_platform_style(req.platform_name or "Instagram")
    # Use the safe accessor so a None field renders as "" (raw getattr would
    # interpolate the literal text "None" into the prompt) and a dict-shaped
    # style object still works.
    voice = _get_style_attr(style, "voice")
    emoji_guideline = _get_style_attr(style, "emoji_guideline")
    hashtag_guideline = _get_style_attr(style, "hashtag_guideline")
    formatting_guideline = _get_style_attr(style, "formatting_guideline")
    extra_notes = _get_style_attr(style, "extra_notes")
    return f"""
You are Marketeer, an expert marketing copywriter.
You help users:
- write first-draft posts
- refine tone
- shorten or expand posts
- adapt copy across platforms
Campaign context:
- Brand: {req.brand}
- Product / offer: {req.product}
- Audience: {req.audience}
- Goal: {req.goal}
- Platform: {req.platform_name}
- Tone: {req.tone}
- CTA style: {req.cta_style}
- Extra context: {req.extra_context}
Platform style guidelines:
- Voice: {voice}
- Emoji usage: {emoji_guideline}
- Hashtags: {hashtag_guideline}
- Formatting: {formatting_guideline}
- Extra notes: {extra_notes}
You may have access to special tools that help you:
- adjust tone
- shorten or expand text
- remove or add emojis
- tweak style
When you respond:
- If the user clearly wants a simple answer, respond directly.
- If the user is asking to rewrite existing text (e.g. "shorten this",
"make it more professional", "remove emojis"), feel free to call tools
if they are available.
- Always return clean, user-ready copy (no JSON, no debug).
""".strip()
def _build_message_history(history_pairs: List[List[str]]) -> List[Message]:
    """
    Turn [[user, assistant], ...] history pairs into LangChain messages.

    Malformed entries (falsy, or not exactly two elements) are skipped;
    an empty string on either side of a pair produces no message.
    """
    converted: List[Message] = []
    for entry in history_pairs:
        if not entry or len(entry) != 2:
            continue
        user_turn, assistant_turn = entry
        if user_turn:
            converted.append(HumanMessage(content=user_turn))
        if assistant_turn:
            converted.append(AIMessage(content=assistant_turn))
    return converted
def _get_tool_map(tools: List[BaseTool]) -> Dict[str, BaseTool]:
    """
    Build a lookup table from a tool's name to the tool instance.
    """
    name_to_tool: Dict[str, BaseTool] = {}
    for candidate in tools:
        name_to_tool[candidate.name] = candidate
    return name_to_tool
# --------------------------------------------------------------------
# Main agent entry point
# --------------------------------------------------------------------
def agent_chat_turn(
    req: CopyRequest,
    user_message: str,
    history_pairs: List[List[str]] | None = None,
) -> Tuple[str, str, list]:
    """
    Run one agent turn: prompt the model, execute any tools it requests,
    then let it answer with the tool results in context.

    Args:
        req: Campaign context used to build the instruction prompt.
        user_message: The user's newest chat message.
        history_pairs: Prior [[user, assistant], ...] turns, or None.

    Returns:
        A ``(final_text, raw_model_output, audit)`` triple. ``audit`` is a
        list of dicts — one per executed tool call, each with keys
        ``"tool"``, ``"args"``, and ``"output"`` — and is empty when the
        model answered directly without tools.
    """
    history_pairs = history_pairs or []
    # 1) Build base messages: instructions + history + the new user message.
    instructions = _build_system_prompt(req)
    # IMPORTANT: use HumanMessage here, not SystemMessage — the local chat
    # model backing this agent does not handle the system role reliably.
    system_msg = HumanMessage(content=instructions)
    history_msgs = _build_message_history(history_pairs)
    new_user_msg = HumanMessage(content=user_message)
    messages: List[Union[Message, ToolMessage]] = (
        [system_msg] + history_msgs + [new_user_msg]
    )
    # 2) Prepare tools & model.
    tools: List[BaseTool] = get_rewrite_tools()
    tool_map = _get_tool_map(tools)
    llm = get_local_chat_model()
    llm_with_tools = llm.bind_tools(tools)
    # 3) First model call: the model decides whether to use tools.
    ai_msg: AIMessage = llm_with_tools.invoke(messages)
    raw_first = ai_msg.content or ""
    # If the model does not request any tools, just return its answer.
    if not getattr(ai_msg, "tool_calls", None):
        return raw_first.strip(), raw_first, []
    # 4) Execute each requested tool, logging every call for the audit trail.
    messages.append(ai_msg)
    audit: List[Dict[str, Any]] = []
    tool_messages: List[ToolMessage] = []
    for tool_call in ai_msg.tool_calls:
        tool_name = tool_call.get("name")
        args = tool_call.get("args", {})
        call_id = tool_call.get("id") or ""
        tool = tool_map.get(tool_name)
        if tool is None:
            tool_output = f"Tool '{tool_name}' is not available."
        else:
            # LangChain tools usually implement .invoke()
            try:
                tool_output = tool.invoke(args)
            except Exception as e:
                # Surface the failure to the model instead of crashing the turn.
                tool_output = f"Tool '{tool_name}' failed with error: {e}"
        audit.append(
            {"tool": tool_name, "args": args, "output": str(tool_output)}
        )
        tool_messages.append(
            ToolMessage(content=str(tool_output), tool_call_id=call_id)
        )
    messages.extend(tool_messages)
    # 5) Second model call: let the LLM see tool results and answer.
    final_ai: AIMessage = llm_with_tools.invoke(messages)
    raw_second = final_ai.content or ""
    return raw_second.strip(), raw_second, audit