Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import json | |
| import random | |
| import logging | |
| import torch | |
| import yaml | |
| from datetime import datetime, timedelta | |
| from typing import Any, Dict, List, Optional, TypedDict | |
| from dotenv import load_dotenv | |
| from langgraph.graph import StateGraph, END | |
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress TF logs
# Lazily populated cache for a loaded Hugging Face text-generation pipeline
# (None until a pipeline has been created).
_GENERATOR = None
# Captures the body of the first ```...``` fenced block; the "json" language tag
# is optional and matched case-insensitively.
_CODEFence_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)
# Static configuration shared by the matching, post-generation and scheduling
# steps in this module.
DEFAULT_CONFIG = {
    # Settings read by the template-matching LLM step.
    "matching": {
        "MODEL_NAME": "mistralai/Mistral-7B-Instruct-v0.2",
        "HF_DEVICE_MAP": "auto",
        "MAX_NEW_TOKENS": 512,
        "TEMPERATURE": 0.2,
        "TOP_P": 0.9,
        # Upper bound on the number of ranked templates kept by the finalize node.
        "TOP_K_RETURN": 10,
    },
    # Settings read by the post-generation LLM step.
    # NOTE(review): this names the v0.1 checkpoint while matching uses v0.2 —
    # confirm the difference is intentional.
    "postgen": {
        "MODEL_NAME": "mistralai/Mistral-7B-Instruct-v0.1",
        "HF_DEVICE_MAP": "auto",
        "MAX_NEW_TOKENS": 512,
        "TEMPERATURE": 0.2,
        "TOP_P": 0.9,
    },
    # Settings for the rule-based scheduler.
    "scheduling": {
        "rules_file": "./rule_based_scheduling_data.json",
        "timezone_offset": 0
    },
    "providers": {
        "hf": {
            # Both tokens are read from the same environment variable at import
            # time; either is None when the variable is unset.
            "token_matching": os.getenv("mistralcopilothf"),
            "token_gen": os.getenv("mistralcopilothf"),
        }
    }
}
def _get_hf_generator_match():
    """
    Return the Hugging Face text-generation pipeline used for template matching.

    The pipeline is created on the first call and cached in the module-level
    ``_GENERATOR``; later calls return the cached object. Model-only (no mock).

    Returns:
        The ``transformers`` text-generation pipeline built for
        ``DEFAULT_CONFIG["matching"]["MODEL_NAME"]``.

    Raises:
        RuntimeError: if no token is configured, or if the pipeline cannot be
            created (for example a gated repo the token has no access to).
            The underlying exception is chained as ``__cause__``.
    """
    global _GENERATOR
    if _GENERATOR is not None:
        return _GENERATOR

    # Imported lazily so the cached path never pays for the import.
    import torch
    from transformers import pipeline

    token = DEFAULT_CONFIG["providers"]["hf"]["token_matching"]
    if not token:
        # DEFAULT_CONFIG fills this value from the `mistralcopilothf`
        # environment variable, so that is the name to report.
        raise RuntimeError(
            "Hugging Face token not found. Set env var mistralcopilothf."
        )

    # dtype selection: bfloat16 on compute capability >= 8, float16 on older
    # GPUs, float32 on CPU.
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        torch_dtype = torch.bfloat16 if major >= 8 else torch.float16
    else:
        torch_dtype = torch.float32

    model_name = DEFAULT_CONFIG["matching"]["MODEL_NAME"]
    try:
        _GENERATOR = pipeline(
            "text-generation",
            model=model_name,
            device_map=DEFAULT_CONFIG["matching"]["HF_DEVICE_MAP"],
            torch_dtype=torch_dtype,
            token=token,
        )
    except Exception as e:
        # Surface helpful error if gated
        raise RuntimeError(
            f"Failed to load model `{model_name}`. "
            "If it's a gated repo, request access and ensure your token has it. "
            f"Original error: {e}"
        ) from e
    return _GENERATOR
| def _normalize_product(p: dict) -> dict: | |
| """ | |
| Accept product with either Go-style TitleCase or pythonic snake/camel. | |
| Return a normalized dict with lowercase keys used by the prompt. | |
| """ | |
| # handle multiple possible casings | |
| def g(k): | |
| return ( | |
| p.get(k) | |
| or p.get(k.lower()) | |
| or p.get(k.capitalize()) | |
| or p.get(k.replace("_", "")) | |
| or p.get(k.upper()) | |
| ) | |
| # Options should be list of {"name":..., "value":...} | |
| options = g("Options") or g("options") or [] | |
| # cast price to string (your Go struct has string price) | |
| price_val = g("Price") | |
| if isinstance(price_val, (int, float)): | |
| price_val = f"{price_val:.2f}" | |
| return { | |
| "id": g("ID") or g("Id") or g("id"), | |
| "name": g("Name") or g("name"), | |
| "category": g("Category") or g("category"), | |
| "type": g("Type") or g("type"), | |
| "price": price_val or "", | |
| "currency": g("Currency") or g("currency") or "", | |
| "description": g("Description") or g("description") or "", | |
| "stock_quantity": g("StockQuantity") or g("stock_quantity") or 0, | |
| "sku": g("SKU") or g("Sku") or g("sku") or "", | |
| "images": g("Images") or g("images") or [], | |
| "options": options, | |
| "on_sale": bool(g("OnSale") if g("OnSale") is not None else g("on_sale") or False), | |
| } | |
| def _normalize_templates(templates: list[dict]) -> list[dict]: | |
| """ | |
| Ensure each template has required keys and add detected language. | |
| Input structure (DynamicTemplate): { id, template, platform, brand_voice } | |
| """ | |
| norm = [] | |
| for t in templates: | |
| tid = t.get("id") or t.get("ID") | |
| txt = t.get("template") or t.get("Template") | |
| platform = (t.get("platform") or t.get("Platform") or "").strip() | |
| brand_voice = t.get("brand_voice") or t.get("BrandVoice") or "" | |
| norm.append({ | |
| "id": tid, | |
| "template": txt, | |
| "platform": platform, | |
| "brand_voice": brand_voice, | |
| }) | |
| return norm | |
def _build_matching_prompt(product: dict, templates10: list[dict]) -> str:
    """
    Build the template-scoring prompt for the matching LLM.

    Your exact prompt shape, kept intact (including the code-fenced JSON example).

    Args:
        product: product dict; the keys id, name, category, type, price,
            currency, description, stock_quantity, sku, options and on_sale
            are read with direct indexing, so all must be present.
        templates10: candidate templates; each must provide template, id,
            platform and brand_voice.

    Returns:
        The assembled prompt with surrounding whitespace stripped.
    """
    # product block
    product_str = f"""Product:
- id: {product['id']}
- name: {product['name']}
- category: {product['category']}
- type: {product['type']}
- price: {product['price']}
- currency: {product['currency']}
- Description: {product['description']}
- stock_quantity: {product['stock_quantity']}
- sku: {product['sku']}
- options: {product['options']}
- on_sale: {product['on_sale']}"""
    # template list (note: keeping "plateform" spelling exactly as your prompt)
    # Templates are numbered from 1 in the order they were passed in.
    template_list = "\n".join([
        f"{i+1}. {t['template']} (id: {t['id']}, plateform: {t['platform']}, brandvoice: {t['brand_voice']})"
        for i, t in enumerate(templates10)
    ])
    # Example of the output shape the model is asked to produce.
    json_example = """```json
[
{ "id": "tpl_005", "score": 0.91 },
{ "id": "tpl_007", "score": 0.85 },
{ "id": "tpl_013", "score": 0.0 }
]
```"""
    # NOTE(review): the wording below always says "10" templates, regardless of
    # how many entries templates10 actually holds.
    prompt = f"""
You are a multilingual social media strategist.
Your task:
Given a product and a list of 10 candidate social media post templates, score the templates from best to worst match.
Evaluate how well each template fits the product based on:
- Relevance to the product's description and type
- Alignment with the platform and brand voice
- Overall marketing appeal and fluency
{product_str}
Templates:
{template_list}
Instructions:
1. Analyze all 10 templates.
2. Return a list of TemplateIDs with a matching score between 0.0 and 1.0.
3. The higher the score, the better the match.
4. All 10 templates must appear in the output, even if their score is 0.0.
5. Output the result as valid JSON inside a single code block, like this:
{json_example}
Now score the templates and return the result which must include the 10 templates with their score .
"""
    return prompt.strip()
def preselect_templates(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Keep the templates whose platform and language both match the request.

    Reads ``templates`` and ``platform`` from ``state`` (both required) and
    ``language`` (defaults to "en"). The matches are stored under
    ``candidate_templates`` and the same state object is returned.
    """
    all_templates = state["templates"]
    wanted_platform = state["platform"]
    wanted_language = state.get("language", "en")

    matches = []
    for tpl in all_templates:
        if tpl["platform"] != wanted_platform:
            continue
        if tpl["language"] != wanted_language:
            continue
        matches.append(tpl)

    state["candidate_templates"] = matches
    return state
| def _extract_json_from_code_block(output_text: str): | |
| import re, json | |
| # Try fenced ```json ... ``` | |
| m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", output_text, re.IGNORECASE) | |
| if m: | |
| candidate = m.group(1).strip() | |
| else: | |
| # Fallback: first JSON-like array | |
| m = re.search(r"(\[\s*\{[\s\S]*?\}\s*\])", output_text) | |
| if not m: | |
| return None | |
| candidate = m.group(1).strip() | |
| candidate = candidate.replace("'", '"') | |
| candidate = candidate.replace("\t", " ") | |
| candidate = candidate.replace("\r", " ") | |
| # remove trailing commas | |
| candidate = re.sub(r",\s*([\]}])", r"\1", candidate) | |
| try: | |
| obj = json.loads(candidate) | |
| if not isinstance(obj, list): | |
| return None | |
| # Normalize keys: accept {"id","score"} or {"template_id","score"} | |
| normalized = [] | |
| for item in obj: | |
| if not isinstance(item, dict): | |
| continue | |
| tid = item.get("id") or item.get("template_id") | |
| sc = item.get("score", 0.0) | |
| if tid is None: | |
| continue | |
| try: | |
| sc = float(sc) | |
| except Exception: | |
| sc = 0.0 | |
| normalized.append({"id": tid, "score": max(0.0, min(1.0, sc))}) | |
| return normalized | |
| except Exception: | |
| return None | |
| def _merge_scores(score_output: list[dict], templates10: list[dict]) -> list[dict]: | |
| # map id->score from LLM | |
| out_map = {s["id"]: s["score"] for s in (score_output or []) if "id" in s} | |
| merged = [] | |
| for t in templates10: | |
| merged.append({ | |
| "id": t["id"], | |
| "template": t["template"], | |
| "platform": t["platform"], | |
| "brand_voice": t["brand_voice"], | |
| "score": float(out_map.get(t["id"], 0.0)) | |
| }) | |
| merged.sort(key=lambda x: x["score"], reverse=True) | |
| return merged | |
def node_normalize_inputs(state: dict) -> dict:
    """
    Graph node: normalize the raw product, templates and platform.

    Writes ``product_norm``, ``templates_norm`` and ``platform_norm`` (the
    platform with surrounding whitespace removed) into ``state`` and returns it.
    Missing inputs default to an empty product, no templates and "".
    """
    raw_product = state.get("product", {})
    raw_templates = state.get("templates", [])
    raw_platform = state.get("platform", "")

    # Normalize both inputs before touching the state.
    product_norm = _normalize_product(raw_product)
    templates_norm = _normalize_templates(raw_templates)

    state.update(
        product_norm=product_norm,
        templates_norm=templates_norm,
        platform_norm=(raw_platform or "").strip(),
    )
    return state
def node_preselect_by_platform_and_language(state: dict) -> dict:
    """
    Graph node: keep up to 10 templates matching the platform and product language.

    The product language is detected from its name and description. A template
    qualifies when its platform equals the requested one (case-insensitive)
    and its detected language equals the product language.

    Language detection is best-effort: empty text or a detection failure
    yields "unknown" instead of aborting the run. When the product language is
    unknown, templates are filtered by platform only; a template whose own
    language cannot be detected is excluded whenever the product language is
    known.

    Writes ``candidates_10`` and ``product_language`` ("" when unknown) into
    ``state`` and returns it.
    """
    from langdetect import detect
    from langdetect.lang_detect_exception import LangDetectException

    def detect_or_none(text):
        # langdetect raises on text without usable features (e.g. empty).
        text = (text or "").strip()
        if not text:
            return None
        try:
            return detect(text)
        except LangDetectException:
            return None

    product = state["product_norm"]
    templates = state["templates_norm"]
    platform = (state["platform_norm"] or "").lower()

    # `or ""` keeps a None name/description from being rendered as "None".
    product_text = f"{product.get('name') or ''} {product.get('description') or ''}"
    product_lang = detect_or_none(product_text)

    filtered = []
    for t in templates:
        if (t.get("platform") or "").lower() != platform:
            continue
        if product_lang is not None and detect_or_none(t.get("template")) != product_lang:
            continue
        filtered.append(t)
        # keep max 10 candidates; stopping here skips needless detection
        if len(filtered) == 10:
            break

    state["candidates_10"] = filtered
    state["product_language"] = product_lang or ""
    return state
def node_build_matching_prompt(state: dict) -> dict:
    """
    Graph node: render the scoring prompt from the normalized product and the
    preselected candidates, and store it under ``matching_prompt``.
    """
    state["matching_prompt"] = _build_matching_prompt(
        state["product_norm"],
        state["candidates_10"],
    )
    return state
def node_llm_infer_scores(state: dict) -> dict:
    """
    Graph node: run the matching prompt through the LLM.

    Reads ``matching_prompt`` and stores the model's text under
    ``llm_raw_output``; a leading echo of the prompt is removed.
    """
    matching_cfg = DEFAULT_CONFIG["matching"]
    generator = _get_hf_generator_match()
    prompt = state["matching_prompt"]

    generation_kwargs = {
        "max_new_tokens": matching_cfg["MAX_NEW_TOKENS"],
        "temperature": matching_cfg["TEMPERATURE"],
        "top_p": matching_cfg["TOP_P"],
        "do_sample": True,
        "eos_token_id": None,
    }
    out = generator(prompt, **generation_kwargs)

    # HF pipelines return list of dicts with 'generated_text'
    if isinstance(out, list):
        raw_text = out[0]["generated_text"]
    else:
        raw_text = str(out)

    # Keep only the part after the prompt if model echoes it
    if raw_text.startswith(prompt):
        raw_text = raw_text[len(prompt):].strip()

    state["llm_raw_output"] = raw_text
    return state
def node_parse_and_merge_scores(state: dict) -> dict:
    """
    Graph node: parse the LLM output into scores and rank the candidates.

    Stores the parsed entries under ``scores_parsed`` (empty list when nothing
    could be parsed) and the ranked candidates under ``ranked_templates``.
    """
    llm_text = state.get("llm_raw_output", "")
    scores = _extract_json_from_code_block(llm_text) or []
    state["scores_parsed"] = scores
    state["ranked_templates"] = _merge_scores(scores, state["candidates_10"])
    return state
def node_finalize_ranked_output(state: dict) -> dict:
    """
    Graph node: truncate the ranking to TOP_K_RETURN and attach debug info.

    ``debug`` holds the prompt and raw output (each cut to 4000 characters),
    the parsed scores and the detected product language.
    """
    top_k = DEFAULT_CONFIG["matching"]["TOP_K_RETURN"]
    keep = min(top_k, len(state.get("ranked_templates", [])))
    state["ranked_templates"] = state["ranked_templates"][:keep]

    # keep compact debug (helpful later when chaining to generation)
    debug_info = {}
    debug_info["prompt"] = state.get("matching_prompt", "")[:4000]
    debug_info["raw_output"] = state.get("llm_raw_output", "")[:4000]
    debug_info["parsed_scores"] = state.get("scores_parsed", [])
    debug_info["product_language"] = state.get("product_language", "")
    state["debug"] = debug_info
    # Clean large intermediates if you want
    return state
def build_matching_graph() -> Any:
    """
    Assemble and compile the linear template-matching graph:
    normalize_inputs -> preselect -> build_prompt -> infer -> parse_merge
    -> finalize -> END.
    """
    steps = [
        ("normalize_inputs", node_normalize_inputs),
        ("preselect", node_preselect_by_platform_and_language),
        ("build_prompt", node_build_matching_prompt),
        ("infer", node_llm_infer_scores),
        ("parse_merge", node_parse_and_merge_scores),
        ("finalize", node_finalize_ranked_output),
    ]

    graph = StateGraph(dict)
    for step_name, step_fn in steps:
        graph.add_node(step_name, step_fn)

    # Entry point is the first step.
    graph.set_entry_point("normalize_inputs")

    # Each step feeds the next; the last one links to the reserved END marker.
    order = [step_name for step_name, _ in steps]
    for source, target in zip(order, order[1:] + [END]):
        graph.add_edge(source, target)

    return graph.compile()
# Expose app
# The matching graph is compiled once, at import time, and reused by
# matching_node().
matching_app = build_matching_graph()
class PostGenState(TypedDict, total=False):
    """State passed between the post-generation graph nodes.

    Declared with ``total=False``, so every key is optional; each node adds the
    key(s) it produces.
    """
    # Inputs expected from previous step
    product: Dict[str, Any]
    ranked: List[Dict[str, Any]] # from matching: [{id, template, platform, brand_voice, score}, ...]
    platform: str
    # Post-gen intermediates
    selected_template: Dict[str, Any]  # highest-scored entry of `ranked`
    post_prompt: str  # prompt sent to the generation model
    post_raw_output: str  # text returned by the model
    post_parsed: Dict[str, Any]  # {text, score, confidence_breakdown}
    # Final
    final_post_struct: Dict[str, Any]
# Cache for the post-generation pipeline. It is kept apart from the matching
# cache (`_GENERATOR`) so that a previously loaded matching model is never
# returned when a different post-generation model is configured.
_POSTGEN_GENERATOR = None


def _get_hf_generator_generator():
    """
    Return the Hugging Face text-generation pipeline used for post generation.

    The pipeline is created on first use and cached. If the matching pipeline
    is already loaded and both steps are configured with the same model name,
    that pipeline is reused instead of loading the model a second time.

    Returns:
        The ``transformers`` text-generation pipeline for
        ``DEFAULT_CONFIG["postgen"]["MODEL_NAME"]``.

    Raises:
        RuntimeError: if no token is configured, or if the pipeline cannot be
            created (for example a gated repo the token has no access to).
            The underlying exception is chained as ``__cause__``.
    """
    global _POSTGEN_GENERATOR
    if _POSTGEN_GENERATOR is not None:
        return _POSTGEN_GENERATOR

    model_name = DEFAULT_CONFIG["postgen"]["MODEL_NAME"]
    if _GENERATOR is not None and model_name == DEFAULT_CONFIG["matching"]["MODEL_NAME"]:
        # Same model for both steps: share the already loaded pipeline.
        _POSTGEN_GENERATOR = _GENERATOR
        return _POSTGEN_GENERATOR

    # NOTE(review): with two different model names, both models stay resident
    # once both steps have run — confirm the host has the memory for that.
    import torch
    from transformers import pipeline

    hf_token = DEFAULT_CONFIG["providers"]["hf"]["token_gen"]
    if not hf_token:
        # DEFAULT_CONFIG fills this value from the `mistralcopilothf`
        # environment variable, so that is the name to report.
        raise RuntimeError(
            "❌ Hugging Face token not found. Please set the environment variable mistralcopilothf in your Space settings."
        )

    # dtype selection: bfloat16 on compute capability >= 8, float16 on older
    # GPUs, float32 on CPU.
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        torch_dtype = torch.bfloat16 if major >= 8 else torch.float16
    else:
        torch_dtype = torch.float32

    try:
        _POSTGEN_GENERATOR = pipeline(
            "text-generation",
            model=model_name,
            device_map=DEFAULT_CONFIG["postgen"]["HF_DEVICE_MAP"],
            torch_dtype=torch_dtype,
            token=hf_token,
        )
    except Exception as e:
        raise RuntimeError(
            f"❌ Failed to load model `{model_name}`. "
            "If it's a gated repo, request access and ensure your HF token has permission. "
            f"Original error: {e}"
        ) from e
    return _POSTGEN_GENERATOR
def build_post_generation_prompt(product, template) -> str:
    """
    Build the few-shot prompt asking the LLM to write and self-score a post.

    The prompt consists of the fixed instructions, two worked examples (one
    English, one French) and finally the given product and template, each
    serialized with ``json.dumps(..., ensure_ascii=False)``.

    Args:
        product: JSON-serializable product object, embedded verbatim.
        template: JSON-serializable template object, embedded verbatim.

    Returns:
        The assembled prompt with surrounding whitespace stripped.
    """
    import json
    # --- few-shot examples (same as fine-tuning) ---
    few1_product = {
        "name": "Herbal Glow Organic Shampoo",
        "category": "Hair Care",
        "type": "Shampoo",
        "price": 14.99,
        "currency": "USD",
        "description": "Nourishing shampoo made with organic argan oil for smooth, shiny hair.",
        "on_sale": True,
        "options": [{"name": "Size", "value": "250ml"}]
    }
    few1_template = {
        "template": "Say goodbye to dull hair! 🌿 [PRODUCT_NAME] is your go-to [CATEGORY] for silky smooth results — now only [PRICE] [CURRENCY]!",
        "score": 0.88,
        "platform": "Instagram",
        "brand_voice": "Natural & Friendly"
    }
    few1_output = {
        "text": "Say goodbye to dull hair! 🌿 Herbal Glow Organic Shampoo is your go-to hair care for silky smooth results — now only 14.99 USD! 💆♀️✨ #HealthyHair #OrganicBeauty",
        "score": 0.95,
        "confidence_breakdown": {"brand_alignment": 0.96, "template_match": 0.88, "clarity_persuasiveness": 0.97}
    }
    few2_product = {
        "name": "Montre Élégance Argentée",
        "category": "Accessoires",
        "type": "Montre",
        "price": 129.90,
        "currency": "EUR",
        "description": "Montre en acier inoxydable, design raffiné pour toutes les occasions.",
        "on_sale": False,
        "options": [{"name": "Couleur", "value": "Argent"}]
    }
    few2_template = {
        "template": "Découvrez [PRODUCT_NAME] — l’[CATEGORY] parfaite pour sublimer votre style. Prix : [PRICE] [CURRENCY].",
        "score": 0.91,
        "platform": "LinkedIn",
        "brand_voice": "Luxueux et professionnel"
    }
    few2_output = {
        "text": "Découvrez Montre Élégance Argentée — l’accessoire parfait pour sublimer votre style ✨. Prix : 129,90 €. Conçue pour les esprits raffinés et les occasions d’exception. #MontresDeLuxe #Élégance",
        "score": 0.93,
        "confidence_breakdown": {"brand_alignment": 0.94, "template_match": 0.91, "clarity_persuasiveness": 0.94}
    }
    # Fixed task description; the OUTPUT FORMAT section is the shape that
    # parse_post_output_llm() later extracts with regular expressions.
    instructions = """
You are an expert social-media copywriter AND a marketing evaluator.
TASK:
- Replace placeholders in the template (e.g. [PRODUCT_NAME], [CATEGORY], [TYPE], [PRICE], [CURRENCY], [OPTION_VALUE]) with the exact values from the PRODUCT object.
- Produce a single, ready-to-post marketing text adapted to:
* the template structure and placeholders,
* the template.brand_voice (tone & vocabulary),
* the template.platform (platform-specific style rules below),
* the product data (use options, on_sale, etc. when relevant).
- Add emojis and 1–5 hashtags consistent with product, platform, and brand voice.
- If product.on_sale is True, mention the deal naturally (if it fits the template).
- Keep language consistent with the template language (if template is French → output in French).
PLATFORM GUIDELINES (apply strictly):
- Instagram: eye-catching, up to 5 hashtags, emojis welcome, slightly conversational.
- TikTok: short, energetic, 1–3 hashtags, call-to-action possible (e.g., "link in bio"), emojis welcome.
- Facebook: friendly, slightly longer allowed, 1–2 hashtags, 0–2 emojis.
- X/Twitter: concise (short sentence), 0–2 hashtags, 0–1 emoji.
- LinkedIn: professional, minimal emojis (0–1), 0–2 hashtags, formal vocabulary.
- Pinterest: descriptive with keywords/hashtags, minimal emojis.
SCORING RULE (how to compute final score):
- brand_alignment = how well tone/emoji/hashtags match template.brand_voice & platform (0.0–1.0).
- template_match = use template['score'] (0.0–1.0) — this reflects semantic match.
- clarity_persuasiveness = how clear, persuasive, and well-structured the post is (0.0–1.0).
- FINAL self_confidence_score = average(brand_alignment, template_match, clarity_persuasiveness). Round to two decimals.
OUTPUT FORMAT (exact — NO extra text, no JSON wrappers, no commentary):
text: "<final post text>"
score: <0.00-1.00>
confidence_breakdown: {"brand_alignment":X, "template_match":Y, "clarity_persuasiveness":Z}
(Use dot as decimal separator for scores; keep post language as required.)
"""
    # The example outputs are rendered in the same text/score/confidence_breakdown
    # layout that the instructions demand.
    prompt = (
        instructions.strip() + "\n\n"
        "FEW-SHOT EXAMPLES\n\n"
        "Example 1 INPUT:\nPRODUCT:\n" + json.dumps(few1_product, ensure_ascii=False) + "\nTEMPLATE:\n" + json.dumps(few1_template, ensure_ascii=False) + "\n\n"
        "Example 1 OUTPUT:\ntext: " + json.dumps(few1_output["text"], ensure_ascii=False) + "\n"
        f"score: {few1_output['score']:.2f}\n"
        "confidence_breakdown: " + json.dumps(few1_output["confidence_breakdown"], ensure_ascii=False) + "\n\n"
        "Example 2 INPUT:\nPRODUCT:\n" + json.dumps(few2_product, ensure_ascii=False) + "\nTEMPLATE:\n" + json.dumps(few2_template, ensure_ascii=False) + "\n\n"
        "Example 2 OUTPUT:\ntext: " + json.dumps(few2_output["text"], ensure_ascii=False) + "\n"
        f"score: {few2_output['score']:.2f}\n"
        "confidence_breakdown: " + json.dumps(few2_output["confidence_breakdown"], ensure_ascii=False) + "\n\n"
        "NOW PROCESS THE NEW INPUT\n\n"
        "INPUT PRODUCT:\n" + json.dumps(product, ensure_ascii=False) + "\n\n"
        "INPUT TEMPLATE:\n" + json.dumps(template, ensure_ascii=False) + "\n\n"
        "OUTPUT:\n"
    )
    return prompt.strip()
def _strip_code_fences(s: str) -> str:
    """
    Return the stripped body of the first fenced code block in ``s``;
    when ``s`` contains no fenced block it is returned unchanged.
    """
    fenced = _CODEFence_RE.search(s)
    if fenced is None:
        return s
    return fenced.group(1).strip()
| def _safe_json_loads(s: str) -> Optional[dict]: | |
| try: | |
| return json.loads(s) | |
| except Exception: | |
| # try common cleanups | |
| s2 = s.replace("“", '"').replace("”", '"').replace("’", "'").replace("‘", "'") | |
| s2 = re.sub(r",\s*(\}|\])", r"\1", s2) # remove trailing commas | |
| s2 = s2.replace("'", '"') | |
| try: | |
| return json.loads(s2) | |
| except Exception: | |
| return None | |
def parse_post_output_llm(raw: str) -> Dict[str, Any]:
    """
    Parse the post-generation LLM output.

    Expected LLM format (from your prompt):
    text: "<final post text>"
    score: <0.00-1.00>
    confidence_breakdown: {"brand_alignment":X, "template_match":Y, "clarity_persuasiveness":Z}

    The quoted text may contain backslash-escaped quotes; JSON escapes in it
    are decoded when possible. A ``score:`` label that is not followed by a
    number is ignored instead of raising.

    Args:
        raw: raw model output, optionally wrapped in a fenced code block.

    Returns:
        dict with keys: text, score, confidence_breakdown (values may be None
        if missing).
    """
    txt = _strip_code_fences(raw)

    # text (quoted): escaped characters such as \" stay inside the capture.
    final_text = None
    text_match = re.search(r'text:\s*"((?:\\.|[^"\\])*)"', txt, flags=re.DOTALL)
    if text_match is None:
        # Lenient fallback: everything up to the next double quote.
        text_match = re.search(r'text:\s*"(.*?)"', txt, flags=re.DOTALL)
    if text_match:
        body = text_match.group(1)
        try:
            # strict=False tolerates literal newlines inside the quoted text.
            decoded = json.loads(f'"{body}"', strict=False)
        except ValueError:
            decoded = body
        final_text = decoded.strip()

    # score (float): the label must be followed by an actual number, so the
    # conversion below cannot be handed an empty string.
    score_match = re.search(r'score:\s*(\d+(?:\.\d+)?|\.\d+)', txt)
    score_val = float(score_match.group(1)) if score_match else None

    # confidence_breakdown (JSON-ish dict)
    brk_match = re.search(r'confidence_breakdown:\s*(\{[\s\S]*?\})', txt)
    breakdown = _safe_json_loads(brk_match.group(1)) if brk_match else None
    breakdown = breakdown if isinstance(breakdown, dict) else {}
    clean_breakdown = {
        "brand_alignment": breakdown.get("brand_alignment", None),
        "template_match": breakdown.get("template_match", None),
        "clarity_persuasiveness": breakdown.get("clarity_persuasiveness", None),
    }

    return {
        "text": final_text,
        "score": score_val,
        "confidence_breakdown": clean_breakdown,
    }
def node_select_top_template(state: PostGenState) -> PostGenState:
    """
    Graph node: pick the highest-scored template from ``ranked``.

    Returns a copy of the state with ``selected_template`` set.

    Raises:
        ValueError: if ``ranked`` is missing or empty.
    """
    candidates = state.get("ranked", [])
    if not candidates:
        raise ValueError("PostGen: 'ranked' list is empty or missing.")

    def score_of(entry):
        return entry.get("score", 0.0)

    # choose highest score (even if input already sorted)
    by_score_desc = sorted(candidates, key=score_of, reverse=True)

    updated = dict(state)
    updated["selected_template"] = by_score_desc[0]
    return updated
def node_build_post_prompt(state: PostGenState) -> PostGenState:
    """
    Graph node: build the generation prompt from the product and the selected
    template; returns a copy of the state with ``post_prompt`` set.
    """
    updated = dict(state)
    updated["post_prompt"] = build_post_generation_prompt(
        state["product"],
        state["selected_template"],
    )
    return updated
def node_generate_post_llm(state: PostGenState) -> PostGenState:
    """
    Graph node: run ``post_prompt`` through the generation pipeline.

    Returns a copy of the state with the generated text (prompt excluded via
    ``return_full_text=False``) stored under ``post_raw_output``.
    """
    postgen_cfg = DEFAULT_CONFIG["postgen"]
    generator = _get_hf_generator_generator()
    prompt = state["post_prompt"]

    generation_kwargs = {
        "max_new_tokens": postgen_cfg["MAX_NEW_TOKENS"],
        "do_sample": True,
        "temperature": postgen_cfg["TEMPERATURE"],
        "top_p": postgen_cfg["TOP_P"],
        "return_full_text": False,
    }
    out = generator(prompt, **generation_kwargs)

    # A non-empty list carries the text in its first entry; anything else is
    # stringified as-is.
    if isinstance(out, list) and out:
        raw = out[0]["generated_text"]
    else:
        raw = str(out)

    updated = dict(state)
    updated["post_raw_output"] = raw
    return updated
def node_parse_post_output(state: PostGenState) -> PostGenState:
    """
    Graph node: parse ``post_raw_output`` and return a copy of the state with
    the result stored under ``post_parsed``.
    """
    updated = dict(state)
    updated["post_parsed"] = parse_post_output_llm(state["post_raw_output"])
    return updated
def node_merge_post_struct(state: PostGenState) -> PostGenState:
    """
    Graph node: combine input IDs with the parsed LLM output.

    The product reaching this node is the raw input product, so its id is
    looked up under ``id``, ``ID`` and ``Id`` (the spellings accepted by the
    product normalization elsewhere in this module). The template id is looked
    up the same way.

    Returns:
        A copy of the state with ``final_post_struct`` holding product_id,
        template_id, final_post, self_confidence_score and confidence_breakdown.
    """
    def id_of(record):
        # First spelling that is present with a non-None value.
        for key in ("id", "ID", "Id"):
            value = record.get(key)
            if value is not None:
                return value
        return None

    product = state["product"]
    template = state["selected_template"]
    parsed = state["post_parsed"]

    final_struct = {
        # IDs come from inputs (NOT from LLM)
        "product_id": id_of(product),
        "template_id": id_of(template),
        # LLM-derived
        "final_post": parsed.get("text"),
        "self_confidence_score": parsed.get("score"),
        "confidence_breakdown": parsed.get("confidence_breakdown"),
    }
    return {**state, "final_post_struct": final_struct}
def build_post_generation_graph():
    """
    Assemble and compile the linear post-generation graph:
    select_top_template -> build_prompt -> generate_post -> parse_output
    -> merge_struct -> END.
    """
    steps = [
        ("select_top_template", node_select_top_template),
        ("build_prompt", node_build_post_prompt),
        ("generate_post", node_generate_post_llm),
        ("parse_output", node_parse_post_output),
        ("merge_struct", node_merge_post_struct),
    ]

    g = StateGraph(PostGenState)
    for step_name, step_fn in steps:
        g.add_node(step_name, step_fn)

    g.set_entry_point("select_top_template")

    # Chain each step to its successor; the final step links to END.
    order = [step_name for step_name, _ in steps]
    for source, target in zip(order, order[1:] + [END]):
        g.add_edge(source, target)

    return g.compile()
# The post-generation graph is compiled once, at import time, and reused by
# postgen_node().
postgen_app=build_post_generation_graph()
class PostScheduler:
    """
    Rule-based scheduler that picks a posting time for a category/platform.

    Rules are loaded from a JSON file shaped as
    ``{category: {platform: [slot, ...]}}`` with an optional ``"default"``
    category. Judging from the parsing code, a slot is a time or time range
    optionally followed by day names, e.g. ``"09:00 monday"``,
    ``"9:00-11:00 weekdays"``, ``"12:00 saturday & sunday"``, or the phrase
    ``"platform default"`` — confirm against the rules file.
    """

    _WEEKDAYS = (
        "monday", "tuesday", "wednesday", "thursday",
        "friday", "saturday", "sunday",
    )
    # Group keywords that expand to several concrete days.
    _DAY_GROUPS = {"weekdays": _WEEKDAYS[:5], "weekend": _WEEKDAYS[5:]}

    def __init__(self, rules_file, timezone_offset=0):
        """
        Args:
            rules_file: path to the JSON rules file.
            timezone_offset: hours added to every computed slot time.
        """
        with open(rules_file, "r", encoding="utf-8") as f:
            self.rules = json.load(f)
        self.timezone_offset = timezone_offset

    def get_schedule(self, category, platform):
        """
        Return a posting time as a ``"%Y-%m-%d %H:%M"`` string.

        The category's rules for the platform are used when present, otherwise
        the ``"default"`` rules. One of the expanded slots is chosen at random.
        A missing (None) category falls back to the default rules.

        Raises:
            ValueError: if neither the category nor the default rules define
                the platform, or if the chosen slot has no parsable time.
        """
        category = (category or "").lower()
        platform = (platform or "").lower()
        cat_rules = self.rules.get(category, {})
        default_rules = self.rules.get("default", {})

        if platform in cat_rules:
            slots = cat_rules[platform]
        elif platform in default_rules:
            slots = default_rules[platform]
        else:
            raise ValueError(f"No scheduling rules for {category} or default / {platform}")

        normalized = []
        for slot in slots:
            normalized.extend(self.normalize_slot(slot, platform, default_rules))

        if not normalized:
            # fallback: post tomorrow at 09:00
            scheduled_datetime = datetime.now().replace(hour=9, minute=0, second=0, microsecond=0) + timedelta(days=1)
            return scheduled_datetime.strftime("%Y-%m-%d %H:%M")

        selected_slot = random.choice(normalized)
        scheduled_datetime = self._parse_slot_to_datetime(selected_slot)
        return scheduled_datetime.strftime("%Y-%m-%d %H:%M")

    def normalize_slot(self, slot: str, platform: str, default_rules: dict) -> list[str]:
        """
        Expand one rule slot into concrete ``"<time> <day>"`` slots.

        - ``"platform default"`` expands to the platform's default slots, each
          normalized in turn.
        - ``"weekdays"`` / ``"weekend"`` expand to one slot per day.
        - ``"<time> a & b"`` expands to one slot per listed day.
        - Anything else is returned as a single lowercased, stripped slot.
        """
        slot = slot.strip().lower()

        if "platform default" in slot:
            expanded = []
            for default_slot in default_rules.get(platform, []) or []:
                if "platform default" in default_slot.lower():
                    # A default that points at itself would recurse forever.
                    continue
                expanded.extend(self.normalize_slot(default_slot, platform, default_rules))
            return expanded

        for group, days in self._DAY_GROUPS.items():
            if group in slot:
                time = slot.split()[0]
                return [f"{time} {day}" for day in days]

        if "&" in slot and " " in slot:
            time, days = slot.split(" ", 1)
            return [f"{time} {d.strip()}" for d in days.split("&")]

        return [slot]

    def _parse_slot_to_datetime(self, slot: str) -> datetime:
        """
        Convert a slot into the next matching datetime after now.

        The time is the first token; for a range (``"9:00-11:00"``) the start
        is used. When the rest of the slot names a weekday, the result falls on
        the next such day (before ``timezone_offset`` is applied); otherwise it
        is the next occurrence of that time.

        Raises:
            ValueError: if the slot does not start with an ``H:MM`` time.
        """
        now = datetime.now()
        slot = slot.strip()
        time_part, _, day_part = slot.partition(" ")

        if "-" in time_part and ":" in time_part:
            start_time = time_part.split("-")[0]
        else:
            start_time = time_part

        match = re.match(r"(\d{1,2}):(\d{2})", start_time)
        if not match:
            raise ValueError(f"Invalid time format in slot: {slot}")
        hour, minute = map(int, match.groups())

        scheduled = now.replace(hour=hour, minute=minute, second=0, microsecond=0)

        # Honour the requested day of week, if the slot names one.
        target_weekday = next(
            (
                self._WEEKDAYS.index(word)
                for word in re.findall(r"[a-z]+", day_part.lower())
                if word in self._WEEKDAYS
            ),
            None,
        )
        if target_weekday is None:
            step = timedelta(days=1)
        else:
            scheduled += timedelta(days=(target_weekday - scheduled.weekday()) % 7)
            step = timedelta(days=7)

        scheduled += timedelta(hours=self.timezone_offset)
        # Never schedule in the past: move forward by a day, or by a week when
        # a specific weekday was requested.
        while scheduled <= now:
            scheduled += step
        return scheduled
class SchedulingState(TypedDict, total=False):
    """State consumed and produced by the scheduling step (all keys optional)."""
    product: Dict[str, Any]
    platform: str
    final_post_struct: Dict[str, Any] # re-use directly
    # final_post_struct plus the chosen "scheduled_time" (written by scheduling_node)
    scheduled_post: Dict[str, Any]
| from typing import Any, Dict, List, TypedDict | |
class GlobalState(TypedDict, total=False):
    """State of the end-to-end graph: matching -> post generation -> scheduling.

    Declared with ``total=False``, so every key is optional.
    """
    # Matching inputs
    product: Dict[str, Any]
    platform: str
    templates: List[Dict[str, Any]]
    # Matching outputs
    ranked_templates: List[Dict[str, Any]]
    # PostGen outputs
    final_post_struct: Dict[str, Any] # product_id, template_id, post text
    # Scheduling outputs
    scheduled_post: Dict[str, Any]
def matching_node(state: dict) -> dict:
    """
    Global-graph node: run the matching subgraph.

    Invokes ``matching_app`` with the product, platform and templates from
    ``state`` and copies the resulting ``ranked_templates`` back into it.
    """
    payload = {key: state[key] for key in ("product", "platform", "templates")}
    payload["candidate_templates"] = []
    payload["top_k"] = 10

    outcome = matching_app.invoke(payload)
    state["ranked_templates"] = outcome["ranked_templates"]
    return state
def prepare_for_postgen(state: GlobalState) -> PostGenState:
    """
    Reshape the matching output into the post-generation input: the ranked
    templates are exposed under ``ranked`` (empty list when absent), next to
    the product and platform.
    """
    product = state["product"]
    ranked = state.get("ranked_templates", [])
    platform = state["platform"]
    return dict(product=product, ranked=ranked, platform=platform)
def postgen_node(state: GlobalState) -> dict:
    """
    Global-graph node: run the post-generation subgraph.

    Invokes ``postgen_app`` with the product, ranked templates and platform,
    then stores the resulting ``final_post_struct`` in ``state``.
    """
    payload = dict(
        product=state["product"],
        ranked=state["ranked_templates"],
        platform=state["platform"],
    )
    outcome = postgen_app.invoke(payload)
    state["final_post_struct"] = outcome["final_post_struct"]
    return state
def prepare_for_scheduling(state: GlobalState) -> SchedulingState:
    """
    Build the scheduling input: product, platform and ``final_post_struct``
    are passed through unchanged, and ``scheduled_post`` starts out empty.
    """
    product = state["product"]
    platform = state["platform"]
    post_struct = state["final_post_struct"]  # no renaming
    return dict(
        product=product,
        platform=platform,
        final_post_struct=post_struct,
        scheduled_post={},
    )
def scheduling_node(state: SchedulingState) -> SchedulingState:
    """
    Global-graph node: choose a posting time for the generated post.

    The product category is read from ``Category`` or ``category`` (the
    product here is the raw input, which may use either casing); a missing
    category is passed on as "". The scheduler is built from the
    ``scheduling`` section of DEFAULT_CONFIG, including its timezone offset.

    Writes ``scheduled_post`` (``final_post_struct`` plus ``scheduled_time``)
    into ``state`` and returns it.
    """
    product = state["product"]
    platform = state["platform"]
    final_post_struct = state["final_post_struct"]

    category = product.get("Category") or product.get("category") or ""

    scheduling_cfg = DEFAULT_CONFIG["scheduling"]
    scheduler = PostScheduler(
        rules_file=scheduling_cfg["rules_file"],
        timezone_offset=scheduling_cfg.get("timezone_offset", 0),
    )
    scheduled_time = scheduler.get_schedule(category, platform)

    state["scheduled_post"] = {
        **final_post_struct,
        "scheduled_time": scheduled_time,
    }
    return state
def build_global_graph():
    """
    Assemble and compile the end-to-end graph:
    matching -> prepare_for_postgen -> postgen -> prepare_for_scheduling
    -> scheduling -> END.
    """
    steps = [
        ("matching", matching_node),
        ("prepare_for_postgen", prepare_for_postgen),
        ("postgen", postgen_node),
        ("prepare_for_scheduling", prepare_for_scheduling),
        ("scheduling", scheduling_node),
    ]

    g = StateGraph(GlobalState)
    # Nodes
    for step_name, step_fn in steps:
        g.add_node(step_name, step_fn)

    # Flow: start at the first step, chain each to the next, finish at END.
    g.set_entry_point("matching")
    order = [step_name for step_name, _ in steps]
    for source, target in zip(order, order[1:] + [END]):
        g.add_edge(source, target)

    return g.compile()