MariamBM commited on
Commit
8db7b5b
·
verified ·
1 Parent(s): 5a9584e

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +122 -0
  2. copilotpy.py +841 -0
  3. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import html
import json

import streamlit as st

# The graph lives in copilotpy.py, so the module name is `copilotpy`.
# (The original `from copilot import ...` raised ModuleNotFoundError.)
from copilotpy import build_global_graph


@st.cache_resource
def _get_global_app():
    """Build the LangGraph pipeline once per process.

    Streamlit re-executes the whole script on every interaction, so without
    caching the (expensive) graph build would run on each rerun.
    """
    return build_global_graph()


global_app = _get_global_app()

st.set_page_config(
    page_title="AI Marketing Copilot",
    page_icon="🤖",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# --- HEADER ---
try:
    st.image("logo.png", width=150)  # project logo (optional asset)
except Exception:
    pass  # a missing logo file must not break the whole app
st.title("AI Marketing Copilot")
st.markdown("Your intelligent assistant for **post generation & scheduling** ✨")

st.divider()

# --- PRODUCT INPUT FORM ---
st.subheader("🛒 Product Information")

with st.form("product_form"):
    # Required fields
    product_id = st.text_input("Product ID *", "PEN0001")
    product_name = st.text_input("Product Name *", "EcoWave Stainless Steel Insulated Bottle")
    product_category = st.text_input("Category *", "Drinkware")
    product_description = st.text_area(
        "Description *",
        "Durable, eco-friendly insulated bottle for everyday use."
    )

    # Optional fields in expander
    with st.expander("🔧 Advanced fields (optional)"):
        product_type = st.text_input("Type", "Bottle")
        product_price = st.text_input("Price", "24.99")
        product_currency = st.text_input("Currency", "USD")
        product_stock = st.number_input("Stock Quantity", value=42, step=1)
        product_sku = st.text_input("SKU", "ECO-SS-500")
        product_options = st.text_area(
            "Options (JSON list)",
            '[{"name": "Size", "value": "500ml"}]'
        )
        product_on_sale = st.checkbox("On Sale?", value=True)

    # Platform selection
    platform = st.selectbox("📱 Target Platform *", ["Instagram", "Twitter", "Facebook", "LinkedIn", "TikTok"])

    submitted = st.form_submit_button("🚀 Generate & Schedule")

if submitted:
    try:
        # Parse options safely; anything that is not a JSON list becomes [].
        try:
            options = json.loads(product_options)
            if not isinstance(options, list):
                options = []
        except Exception:
            options = []

        # Build product dict in the shape the pipeline expects.
        product = {
            "id": product_id,
            "name": product_name,
            "category": product_category,
            "type": product_type,
            "price": product_price,
            "currency": product_currency,
            "description": product_description,
            "stock_quantity": product_stock,
            "sku": product_sku,
            "images": [],
            "options": options,
            "on_sale": product_on_sale,
        }

        # Templates placeholder (normally loaded from DB)
        templates = []

        with st.spinner("🤖 Generating post and scheduling..."):
            result = global_app.invoke({
                "product": product,
                "platform": platform,
                "templates": templates,
            })

        # Persist the result: clicking any button below triggers a rerun in
        # which `submitted` is False again, so the result must survive in
        # session_state or the output section disappears.
        st.session_state["copilot_result"] = result
    except Exception as e:
        st.error(f"An error occurred: {e}")

result = st.session_state.get("copilot_result")
if result is not None:
    st.success("✅ Post Generated & Scheduled!")

    # --- MAIN OUTPUT CARD ---
    st.subheader("📢 Final Post")
    struct = result.get("final_post_struct") or {}
    # The pipeline stores the text under "final_post"; keep "post_text" as a
    # fallback for older payloads.
    final_post = struct.get("final_post") or struct.get("post_text") or "⚠️ No post generated"
    st.markdown(
        f"""
        <div style="padding:1.2em; border-radius:10px; background-color:#F0F9FF; border:1px solid #90CAF9;">
        <p style="font-size:1.1em;">{html.escape(str(final_post))}</p>
        </div>
        """,
        unsafe_allow_html=True
    )

    # Scheduled time
    st.subheader("⏰ Scheduled Time")
    st.info((result.get("scheduled_post") or {}).get("scheduled_time", "Not scheduled"))

    # --- EXTRA: Ranked Templates ---
    if st.button("📊 Show Template Rankings"):
        ranked = result.get("ranked_templates", [])
        if ranked:
            st.markdown("### Template Scores")
            st.dataframe(ranked)
        else:
            st.warning("No ranked templates available.")

# --- FOOTER ---
st.divider()
st.caption("⚡ Powered by LangGraph + Hugging Face Spaces")
copilotpy.py ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
import json
import random
import logging
import torch
import yaml

from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, TypedDict

from dotenv import load_dotenv
from langgraph.graph import StateGraph, END

# Must be set before TensorFlow is imported anywhere downstream.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress TF logs

# Lazily-created Hugging Face text-generation pipeline (see the
# _get_hf_generator_* loaders below).
_GENERATOR = None
# Captures the body of a ``` / ```json fenced block.
_CODEFence_RE = re.compile(r"```(?:json)?\s*([\s\S]*?)\s*```", re.IGNORECASE)


def _hf_token():
    """Resolve the Hugging Face API token.

    The error messages in this module tell users to set HF_TOKEN /
    HUGGINGFACE_TOKEN, but the config previously read only the
    Space-specific secret ``mistralcopilothf`` — accept all three,
    most specific first.
    """
    return (
        os.getenv("mistralcopilothf")
        or os.getenv("HF_TOKEN")
        or os.getenv("HUGGINGFACE_TOKEN")
    )


DEFAULT_CONFIG = {
    "matching": {
        "MODEL_NAME": "mistralai/Mistral-7B-Instruct-v0.2",
        "HF_DEVICE_MAP": "auto",
        "MAX_NEW_TOKENS": 512,
        "TEMPERATURE": 0.2,
        "TOP_P": 0.9,
        "TOP_K_RETURN": 10,
    },
    "postgen": {
        "MODEL_NAME": "mistralai/Mistral-7B-Instruct-v0.1",
        "HF_DEVICE_MAP": "auto",
        "MAX_NEW_TOKENS": 512,
        "TEMPERATURE": 0.2,
        "TOP_P": 0.9,
    },
    "scheduling": {
        "rules_file": "./rule_based_scheduling_data.json",
        "timezone_offset": 0,
    },
    "providers": {
        "hf": {
            "token_matching": _hf_token(),
            "token_gen": _hf_token(),
        }
    },
}
47
+
48
+
49
def _get_hf_generator_match():
    """
    Create (once) a Hugging Face text-generation pipeline for the matching
    model.  Model-only (no mock).

    The pipeline is cached on the function itself rather than the module-level
    ``_GENERATOR``: that global is shared with ``_get_hf_generator_generator``,
    so whichever loader ran first used to win and the other silently reused a
    pipeline built for a *different* model.

    Raises:
        RuntimeError: if no HF token is configured, or the model fails to
            load (e.g. a gated repo the token has no access to).
    """
    cached = getattr(_get_hf_generator_match, "_pipeline", None)
    if cached is not None:
        return cached

    import torch
    from transformers import pipeline

    token = DEFAULT_CONFIG["providers"]["hf"]["token_matching"]
    if not token:
        raise RuntimeError(
            "Hugging Face token not found. Set env var HUGGINGFACE_TOKEN (or HF_TOKEN)."
        )

    # dtype selection: bf16 on Ampere+ GPUs, fp16 on older GPUs, fp32 on CPU.
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        torch_dtype = torch.bfloat16 if major >= 8 else torch.float16
    else:
        torch_dtype = torch.float32

    model_name = DEFAULT_CONFIG["matching"]["MODEL_NAME"]
    try:
        generator = pipeline(
            "text-generation",
            model=model_name,
            device_map=DEFAULT_CONFIG["matching"]["HF_DEVICE_MAP"],
            torch_dtype=torch_dtype,
            token=token,
        )
    except Exception as e:
        # Surface a helpful error (e.g. gated repo) including the model name.
        raise RuntimeError(
            f"Failed to load model `{model_name}`. "
            "If it's a gated repo, request access and ensure your token has it. "
            f"Original error: {e}"
        ) from e

    _get_hf_generator_match._pipeline = generator
    return generator
93
+
94
+ def _normalize_product(p: dict) -> dict:
95
+ """
96
+ Accept product with either Go-style TitleCase or pythonic snake/camel.
97
+ Return a normalized dict with lowercase keys used by the prompt.
98
+ """
99
+ # handle multiple possible casings
100
+ def g(k):
101
+ return (
102
+ p.get(k)
103
+ or p.get(k.lower())
104
+ or p.get(k.capitalize())
105
+ or p.get(k.replace("_", ""))
106
+ or p.get(k.upper())
107
+ )
108
+ # Options should be list of {"name":..., "value":...}
109
+ options = g("Options") or g("options") or []
110
+ # cast price to string (your Go struct has string price)
111
+ price_val = g("Price")
112
+ if isinstance(price_val, (int, float)):
113
+ price_val = f"{price_val:.2f}"
114
+ return {
115
+ "id": g("ID") or g("Id") or g("id"),
116
+ "name": g("Name") or g("name"),
117
+ "category": g("Category") or g("category"),
118
+ "type": g("Type") or g("type"),
119
+ "price": price_val or "",
120
+ "currency": g("Currency") or g("currency") or "",
121
+ "description": g("Description") or g("description") or "",
122
+ "stock_quantity": g("StockQuantity") or g("stock_quantity") or 0,
123
+ "sku": g("SKU") or g("Sku") or g("sku") or "",
124
+ "images": g("Images") or g("images") or [],
125
+ "options": options,
126
+ "on_sale": bool(g("OnSale") if g("OnSale") is not None else g("on_sale") or False),
127
+ }
128
+
129
+ def _normalize_templates(templates: list[dict]) -> list[dict]:
130
+ """
131
+ Ensure each template has required keys and add detected language.
132
+ Input structure (DynamicTemplate): { id, template, platform, brand_voice }
133
+ """
134
+ norm = []
135
+ for t in templates:
136
+ tid = t.get("id") or t.get("ID")
137
+ txt = t.get("template") or t.get("Template")
138
+ platform = (t.get("platform") or t.get("Platform") or "").strip()
139
+ brand_voice = t.get("brand_voice") or t.get("BrandVoice") or ""
140
+ norm.append({
141
+ "id": tid,
142
+ "template": txt,
143
+ "platform": platform,
144
+ "brand_voice": brand_voice,
145
+ })
146
+ return norm
147
+
148
+ def _build_matching_prompt(product: dict, templates10: list[dict]) -> str:
149
+ """
150
+ Your exact prompt shape, kept intact (including the code-fenced JSON example).
151
+ """
152
+ # product block
153
+ product_str = f"""Product:
154
+ - id: {product['id']}
155
+ - name: {product['name']}
156
+ - category: {product['category']}
157
+ - type: {product['type']}
158
+ - price: {product['price']}
159
+ - currency: {product['currency']}
160
+ - Description: {product['description']}
161
+ - stock_quantity: {product['stock_quantity']}
162
+ - sku: {product['sku']}
163
+ - options: {product['options']}
164
+ - on_sale: {product['on_sale']}"""
165
+
166
+ # template list (note: keeping "plateform" spelling exactly as your prompt)
167
+ template_list = "\n".join([
168
+ f"{i+1}. {t['template']} (id: {t['id']}, plateform: {t['platform']}, brandvoice: {t['brand_voice']})"
169
+ for i, t in enumerate(templates10)
170
+ ])
171
+
172
+ json_example = """```json
173
+ [
174
+ { "id": "tpl_005", "score": 0.91 },
175
+ { "id": "tpl_007", "score": 0.85 },
176
+ { "id": "tpl_013", "score": 0.0 }
177
+ ]
178
+ ```"""
179
+
180
+ prompt = f"""
181
+ You are a multilingual social media strategist.
182
+
183
+ Your task:
184
+ Given a product and a list of 10 candidate social media post templates, score the templates from best to worst match.
185
+
186
+ Evaluate how well each template fits the product based on:
187
+ - Relevance to the product's description and type
188
+ - Alignment with the platform and brand voice
189
+ - Overall marketing appeal and fluency
190
+
191
+ {product_str}
192
+
193
+ Templates:
194
+ {template_list}
195
+
196
+ Instructions:
197
+ 1. Analyze all 10 templates.
198
+ 2. Return a list of TemplateIDs with a matching score between 0.0 and 1.0.
199
+ 3. The higher the score, the better the match.
200
+ 4. All 10 templates must appear in the output, even if their score is 0.0.
201
+ 5. Output the result as valid JSON inside a single code block, like this:
202
+
203
+ {json_example}
204
+
205
+ Now score the templates and return the result which must include the 10 templates with their score .
206
+ """
207
+ return prompt.strip()
208
+
209
+
210
def preselect_templates(state: Dict[str, Any]) -> Dict[str, Any]:
    """Filter templates down to those matching the requested platform and
    language, storing them under ``state["candidate_templates"]``.

    Uses ``dict.get`` for the template fields: normalized templates carry no
    "language" key, so the original ``t["language"]`` lookup raised KeyError.
    Templates without a language are now simply filtered out.
    """
    templates = state["templates"]
    platform = state["platform"]
    lang = state.get("language", "en")
    state["candidate_templates"] = [
        t for t in templates
        if t.get("platform") == platform and t.get("language") == lang
    ]
    return state
218
+
219
+ def _extract_json_from_code_block(output_text: str):
220
+ import re, json
221
+ # Try fenced ```json ... ```
222
+ m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", output_text, re.IGNORECASE)
223
+ if m:
224
+ candidate = m.group(1).strip()
225
+ else:
226
+ # Fallback: first JSON-like array
227
+ m = re.search(r"(\[\s*\{[\s\S]*?\}\s*\])", output_text)
228
+ if not m:
229
+ return None
230
+ candidate = m.group(1).strip()
231
+
232
+ candidate = candidate.replace("'", '"')
233
+ candidate = candidate.replace("\t", " ")
234
+ candidate = candidate.replace("\r", " ")
235
+ # remove trailing commas
236
+ candidate = re.sub(r",\s*([\]}])", r"\1", candidate)
237
+
238
+ try:
239
+ obj = json.loads(candidate)
240
+ if not isinstance(obj, list):
241
+ return None
242
+ # Normalize keys: accept {"id","score"} or {"template_id","score"}
243
+ normalized = []
244
+ for item in obj:
245
+ if not isinstance(item, dict):
246
+ continue
247
+ tid = item.get("id") or item.get("template_id")
248
+ sc = item.get("score", 0.0)
249
+ if tid is None:
250
+ continue
251
+ try:
252
+ sc = float(sc)
253
+ except Exception:
254
+ sc = 0.0
255
+ normalized.append({"id": tid, "score": max(0.0, min(1.0, sc))})
256
+ return normalized
257
+ except Exception:
258
+ return None
259
+
260
+ def _merge_scores(score_output: list[dict], templates10: list[dict]) -> list[dict]:
261
+ # map id->score from LLM
262
+ out_map = {s["id"]: s["score"] for s in (score_output or []) if "id" in s}
263
+ merged = []
264
+ for t in templates10:
265
+ merged.append({
266
+ "id": t["id"],
267
+ "template": t["template"],
268
+ "platform": t["platform"],
269
+ "brand_voice": t["brand_voice"],
270
+ "score": float(out_map.get(t["id"], 0.0))
271
+ })
272
+ merged.sort(key=lambda x: x["score"], reverse=True)
273
+ return merged
274
+
275
def node_normalize_inputs(state: dict) -> dict:
    """Graph node: normalize the raw product/templates/platform inputs into
    the ``*_norm`` keys consumed by the rest of the matching pipeline."""
    state["product_norm"] = _normalize_product(state.get("product", {}))
    state["templates_norm"] = _normalize_templates(state.get("templates", []))
    state["platform_norm"] = (state.get("platform", "") or "").strip()
    return state
286
+
287
def node_preselect_by_platform_and_language(state: dict) -> dict:
    """Graph node: keep at most 10 templates matching the product's platform
    and detected language.

    ``langdetect.detect`` raises LangDetectException on empty or
    feature-less text, and the original unguarded calls crashed the whole
    graph in that case.  Detection failures now fall back to "en" for the
    product and exclude the template.
    """
    from langdetect import detect

    def _safe_detect(text, default=None):
        # NOTE(review): langdetect is also non-deterministic unless a seed
        # is set; a failed detection is treated as "unknown", not a crash.
        try:
            return detect(text)
        except Exception:
            return default

    product = state["product_norm"]
    templates = state["templates_norm"]
    platform = state["platform_norm"]

    product_lang = _safe_detect(
        f"{product.get('name', '')} {product.get('description', '')}", default="en"
    )

    filtered = [
        t for t in templates
        if t["platform"].lower() == platform.lower()
        and _safe_detect(t["template"]) == product_lang
    ]

    # keep max 10 candidates
    state["candidates_10"] = filtered[:10]
    state["product_language"] = product_lang
    return state
304
+
305
def node_build_matching_prompt(state: dict) -> dict:
    """Graph node: render the scoring prompt for the current candidates."""
    state["matching_prompt"] = _build_matching_prompt(
        state["product_norm"], state["candidates_10"]
    )
    return state
311
+
312
def node_llm_infer_scores(state: dict) -> dict:
    """Graph node: run the matching LLM on the prompt and stash its raw text
    output under ``llm_raw_output``."""
    generator = _get_hf_generator_match()
    prompt = state["matching_prompt"]

    cfg = DEFAULT_CONFIG["matching"]
    result = generator(
        prompt,
        max_new_tokens=cfg["MAX_NEW_TOKENS"],
        temperature=cfg["TEMPERATURE"],
        top_p=cfg["TOP_P"],
        do_sample=True,
        eos_token_id=None,
    )
    # text-generation pipelines return [{"generated_text": ...}].
    text = result[0]["generated_text"] if isinstance(result, list) else str(result)
    # Some models echo the prompt; keep only the continuation.
    if text.startswith(prompt):
        text = text[len(prompt):].strip()
    state["llm_raw_output"] = text
    return state
331
+
332
def node_parse_and_merge_scores(state: dict) -> dict:
    """Graph node: parse the LLM's JSON scores and merge them onto the
    candidate templates, producing ``ranked_templates``."""
    parsed = _extract_json_from_code_block(state.get("llm_raw_output", "")) or []
    state["scores_parsed"] = parsed
    state["ranked_templates"] = _merge_scores(parsed, state["candidates_10"])
    return state
339
+
340
def node_finalize_ranked_output(state: dict) -> dict:
    """Graph node: truncate the rankings to TOP_K_RETURN and attach a compact
    debug payload (prompt / raw output capped at 4000 chars).

    Reads ``ranked_templates`` once via ``.get``: the original computed ``k``
    defensively with ``.get`` but then indexed ``state["ranked_templates"]``
    directly, raising KeyError when the key was missing.
    """
    ranked = state.get("ranked_templates", [])
    top_k = min(DEFAULT_CONFIG["matching"]["TOP_K_RETURN"], len(ranked))
    state["ranked_templates"] = ranked[:top_k]
    # keep compact debug (helpful later when chaining to generation)
    state["debug"] = {
        "prompt": state.get("matching_prompt", "")[:4000],
        "raw_output": state.get("llm_raw_output", "")[:4000],
        "parsed_scores": state.get("scores_parsed", []),
        "product_language": state.get("product_language", ""),
    }
    return state
352
+
353
+
354
def build_matching_graph() -> Any:
    """Compile the template-matching LangGraph pipeline:
    normalize -> preselect -> build prompt -> infer -> parse/merge -> finalize.
    """
    graph = StateGraph(dict)

    steps = [
        ("normalize_inputs", node_normalize_inputs),
        ("preselect", node_preselect_by_platform_and_language),
        ("build_prompt", node_build_matching_prompt),
        ("infer", node_llm_infer_scores),
        ("parse_merge", node_parse_and_merge_scores),
        ("finalize", node_finalize_ranked_output),
    ]
    for name, fn in steps:
        graph.add_node(name, fn)

    graph.set_entry_point(steps[0][0])
    for (src, _), (dst, _) in zip(steps, steps[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(steps[-1][0], END)  # END is LangGraph's reserved sink

    return graph.compile()


# Module-level compiled app for importers.
matching_app = build_matching_graph()
380
+
381
+
382
class PostGenState(TypedDict, total=False):
    """State flowing through the post-generation graph (all keys optional)."""

    # Inputs expected from the matching step:
    product: Dict[str, Any]
    # [{id, template, platform, brand_voice, score}, ...]
    ranked: List[Dict[str, Any]]
    platform: str

    # Post-gen intermediates:
    selected_template: Dict[str, Any]
    post_prompt: str
    post_raw_output: str
    post_parsed: Dict[str, Any]

    # Final output:
    final_post_struct: Dict[str, Any]
396
+
397
+
398
def _get_hf_generator_generator():
    """
    Create (once) the Hugging Face text-generation pipeline for the
    post-generation model.

    Cached on the function itself instead of the module-level ``_GENERATOR``,
    which is shared with ``_get_hf_generator_match``: with the shared global,
    whichever loader ran first won and the other silently reused a pipeline
    built for a *different* model.

    Raises:
        RuntimeError: when no HF token is configured or loading fails.
    """
    cached = getattr(_get_hf_generator_generator, "_pipeline", None)
    if cached is not None:
        return cached

    import torch
    from transformers import pipeline

    hf_token = DEFAULT_CONFIG["providers"]["hf"]["token_gen"]
    if not hf_token:
        raise RuntimeError(
            "❌ Hugging Face token not found. Please set the environment variable HF_TOKEN in your Space settings."
        )

    # dtype selection: bf16 on Ampere+ GPUs, fp16 on older GPUs, fp32 on CPU.
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        torch_dtype = torch.bfloat16 if major >= 8 else torch.float16
    else:
        torch_dtype = torch.float32

    model_name = DEFAULT_CONFIG["postgen"]["MODEL_NAME"]
    try:
        generator = pipeline(
            "text-generation",
            model=model_name,
            device_map=DEFAULT_CONFIG["postgen"]["HF_DEVICE_MAP"],
            torch_dtype=torch_dtype,
            token=hf_token,
        )
    except Exception as e:
        raise RuntimeError(
            f"❌ Failed to load model `{model_name}`. "
            "If it's a gated repo, request access and ensure your HF token has permission. "
            f"Original error: {e}"
        ) from e

    _get_hf_generator_generator._pipeline = generator
    return generator
436
+
437
+
438
def build_post_generation_prompt(product, template):
    """Assemble the few-shot post-generation prompt for one product+template
    pair: fixed instructions, two worked examples (mirroring fine-tuning
    data), then the new input."""
    import json

    def dumps(obj):
        # Keep accented characters readable in the prompt.
        return json.dumps(obj, ensure_ascii=False)

    # --- few-shot examples (same as fine-tuning) ---
    few1_product = {
        "name": "Herbal Glow Organic Shampoo",
        "category": "Hair Care",
        "type": "Shampoo",
        "price": 14.99,
        "currency": "USD",
        "description": "Nourishing shampoo made with organic argan oil for smooth, shiny hair.",
        "on_sale": True,
        "options": [{"name": "Size", "value": "250ml"}]
    }
    few1_template = {
        "template": "Say goodbye to dull hair! 🌿 [PRODUCT_NAME] is your go-to [CATEGORY] for silky smooth results — now only [PRICE] [CURRENCY]!",
        "score": 0.88,
        "platform": "Instagram",
        "brand_voice": "Natural & Friendly"
    }
    few1_output = {
        "text": "Say goodbye to dull hair! 🌿 Herbal Glow Organic Shampoo is your go-to hair care for silky smooth results — now only 14.99 USD! 💆‍♀️✨ #HealthyHair #OrganicBeauty",
        "score": 0.95,
        "confidence_breakdown": {"brand_alignment": 0.96, "template_match": 0.88, "clarity_persuasiveness": 0.97}
    }

    few2_product = {
        "name": "Montre Élégance Argentée",
        "category": "Accessoires",
        "type": "Montre",
        "price": 129.90,
        "currency": "EUR",
        "description": "Montre en acier inoxydable, design raffiné pour toutes les occasions.",
        "on_sale": False,
        "options": [{"name": "Couleur", "value": "Argent"}]
    }
    few2_template = {
        "template": "Découvrez [PRODUCT_NAME] — l’[CATEGORY] parfaite pour sublimer votre style. Prix : [PRICE] [CURRENCY].",
        "score": 0.91,
        "platform": "LinkedIn",
        "brand_voice": "Luxueux et professionnel"
    }
    few2_output = {
        "text": "Découvrez Montre Élégance Argentée — l’accessoire parfait pour sublimer votre style ✨. Prix : 129,90 €. Conçue pour les esprits raffinés et les occasions d’exception. #MontresDeLuxe #Élégance",
        "score": 0.93,
        "confidence_breakdown": {"brand_alignment": 0.94, "template_match": 0.91, "clarity_persuasiveness": 0.94}
    }

    instructions = """
You are an expert social-media copywriter AND a marketing evaluator.
TASK:
- Replace placeholders in the template (e.g. [PRODUCT_NAME], [CATEGORY], [TYPE], [PRICE], [CURRENCY], [OPTION_VALUE]) with the exact values from the PRODUCT object.
- Produce a single, ready-to-post marketing text adapted to:
* the template structure and placeholders,
* the template.brand_voice (tone & vocabulary),
* the template.platform (platform-specific style rules below),
* the product data (use options, on_sale, etc. when relevant).
- Add emojis and 1–5 hashtags consistent with product, platform, and brand voice.
- If product.on_sale is True, mention the deal naturally (if it fits the template).
- Keep language consistent with the template language (if template is French → output in French).
PLATFORM GUIDELINES (apply strictly):
- Instagram: eye-catching, up to 5 hashtags, emojis welcome, slightly conversational.
- TikTok: short, energetic, 1–3 hashtags, call-to-action possible (e.g., "link in bio"), emojis welcome.
- Facebook: friendly, slightly longer allowed, 1–2 hashtags, 0–2 emojis.
- X/Twitter: concise (short sentence), 0–2 hashtags, 0–1 emoji.
- LinkedIn: professional, minimal emojis (0–1), 0–2 hashtags, formal vocabulary.
- Pinterest: descriptive with keywords/hashtags, minimal emojis.
SCORING RULE (how to compute final score):
- brand_alignment = how well tone/emoji/hashtags match template.brand_voice & platform (0.0–1.0).
- template_match = use template['score'] (0.0–1.0) — this reflects semantic match.
- clarity_persuasiveness = how clear, persuasive, and well-structured the post is (0.0–1.0).
- FINAL self_confidence_score = average(brand_alignment, template_match, clarity_persuasiveness). Round to two decimals.
OUTPUT FORMAT (exact — NO extra text, no JSON wrappers, no commentary):
text: "<final post text>"
score: <0.00-1.00>
confidence_breakdown: {"brand_alignment":X, "template_match":Y, "clarity_persuasiveness":Z}
(Use dot as decimal separator for scores; keep post language as required.)
"""

    def example_block(label, ex_product, ex_template, ex_output):
        # One worked INPUT/OUTPUT pair in the exact output format requested.
        return (
            f"Example {label} INPUT:\nPRODUCT:\n" + dumps(ex_product)
            + "\nTEMPLATE:\n" + dumps(ex_template) + "\n\n"
            + f"Example {label} OUTPUT:\ntext: " + dumps(ex_output["text"]) + "\n"
            + f"score: {ex_output['score']:.2f}\n"
            + "confidence_breakdown: " + dumps(ex_output["confidence_breakdown"]) + "\n\n"
        )

    prompt = (
        instructions.strip() + "\n\n"
        "FEW-SHOT EXAMPLES\n\n"
        + example_block("1", few1_product, few1_template, few1_output)
        + example_block("2", few2_product, few2_template, few2_output)
        + "NOW PROCESS THE NEW INPUT\n\n"
        + "INPUT PRODUCT:\n" + dumps(product) + "\n\n"
        + "INPUT TEMPLATE:\n" + dumps(template) + "\n\n"
        + "OUTPUT:\n"
    )

    return prompt.strip()
535
+
536
+
537
+ def _strip_code_fences(s: str) -> str:
538
+ m = _CODEFence_RE.search(s)
539
+ return m.group(1).strip() if m else s
540
+
541
+ def _safe_json_loads(s: str) -> Optional[dict]:
542
+ try:
543
+ return json.loads(s)
544
+ except Exception:
545
+ # try common cleanups
546
+ s2 = s.replace("“", '"').replace("”", '"').replace("’", "'").replace("‘", "'")
547
+ s2 = re.sub(r",\s*(\}|\])", r"\1", s2) # remove trailing commas
548
+ s2 = s2.replace("'", '"')
549
+ try:
550
+ return json.loads(s2)
551
+ except Exception:
552
+ return None
553
+
554
+
555
def parse_post_output_llm(raw: str) -> Dict[str, Any]:
    r"""
    Parse the LLM's post-generation output of the expected form:

        text: "<final post text>"
        score: <0.00-1.00>
        confidence_breakdown: {"brand_alignment":X, "template_match":Y, "clarity_persuasiveness":Z}

    Returns {"text", "score", "confidence_breakdown"}; "text"/"score" are
    None when missing, and the breakdown dict always carries its three keys
    (possibly mapped to None).

    The score pattern requires at least one digit: the original alternation
    ``[01]?(?:\.\d+)?`` could match the empty string, making
    ``float(match.group(1))`` raise ValueError on output like ``score: n/a``.
    """
    txt = _strip_code_fences(raw)

    # quoted post text (may span lines)
    text_match = re.search(r'text:\s*"(.*?)"', txt, flags=re.DOTALL)
    final_text = text_match.group(1).strip() if text_match else None

    # numeric confidence score — at least one digit required
    score_match = re.search(r'score:\s*(\d+(?:\.\d+)?)', txt)
    score_val = float(score_match.group(1)) if score_match else None

    # JSON-ish breakdown dict
    brk_match = re.search(r'confidence_breakdown:\s*(\{[\s\S]*?\})', txt)
    breakdown = _safe_json_loads(brk_match.group(1)) if brk_match else None
    if not isinstance(breakdown, dict):
        breakdown = {}

    return {
        "text": final_text,
        "score": score_val,
        "confidence_breakdown": {
            "brand_alignment": breakdown.get("brand_alignment"),
            "template_match": breakdown.get("template_match"),
            "clarity_persuasiveness": breakdown.get("clarity_persuasiveness"),
        },
    }
589
+
590
+
591
def node_select_top_template(state: "PostGenState") -> "PostGenState":
    """Graph node: pick the highest-scoring ranked template.

    Re-derives the winner with ``max`` even when the input is already
    sorted; missing scores count as 0.0.
    """
    ranked = state.get("ranked", [])
    if not ranked:
        raise ValueError("PostGen: 'ranked' list is empty or missing.")
    best = max(ranked, key=lambda entry: entry.get("score", 0.0))
    return {**state, "selected_template": best}
598
+
599
+
600
def node_build_post_prompt(state: "PostGenState") -> "PostGenState":
    """Graph node: render the few-shot generation prompt for the selected
    template and product."""
    prompt = build_post_generation_prompt(state["product"], state["selected_template"])
    return {**state, "post_prompt": prompt}
605
+
606
def node_generate_post_llm(state: "PostGenState") -> "PostGenState":
    """Graph node: run the post-generation LLM on the prepared prompt and
    store the raw continuation text."""
    generator = _get_hf_generator_generator()
    cfg = DEFAULT_CONFIG["postgen"]
    out = generator(
        state["post_prompt"],
        max_new_tokens=cfg["MAX_NEW_TOKENS"],
        do_sample=True,
        temperature=cfg["TEMPERATURE"],
        top_p=cfg["TOP_P"],
        return_full_text=False,  # don't echo the prompt back
    )
    raw = out[0]["generated_text"] if isinstance(out, list) and out else str(out)
    return {**state, "post_raw_output": raw}
620
+
621
+
622
def node_parse_post_output(state: "PostGenState") -> "PostGenState":
    """Graph node: parse the raw LLM text into {text, score, breakdown}."""
    return {**state, "post_parsed": parse_post_output_llm(state["post_raw_output"])}
626
+
627
+
628
def node_merge_post_struct(state: "PostGenState") -> "PostGenState":
    """Graph node: combine the input IDs with the parsed LLM fields into the
    final post structure.  IDs always come from the inputs, never the LLM."""
    parsed = state["post_parsed"]
    final_struct = {
        "product_id": state["product"].get("id"),
        "template_id": state["selected_template"].get("id"),
        # LLM-derived fields:
        "final_post": parsed.get("text"),
        "self_confidence_score": parsed.get("score"),
        "confidence_breakdown": parsed.get("confidence_breakdown"),
    }
    return {**state, "final_post_struct": final_struct}
643
+
644
+
645
def build_post_generation_graph():
    """Compile the post-generation LangGraph pipeline:
    select template -> build prompt -> generate -> parse -> merge."""
    graph = StateGraph(PostGenState)

    steps = [
        ("select_top_template", node_select_top_template),
        ("build_prompt", node_build_post_prompt),
        ("generate_post", node_generate_post_llm),
        ("parse_output", node_parse_post_output),
        ("merge_struct", node_merge_post_struct),
    ]
    for name, fn in steps:
        graph.add_node(name, fn)

    graph.set_entry_point(steps[0][0])
    for (src, _), (dst, _) in zip(steps, steps[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(steps[-1][0], END)

    return graph.compile()


# Compiled post-generation app for importers.
postgen_app = build_post_generation_graph()
663
+
664
class PostScheduler:
    """Rule-based scheduler: picks a posting datetime for a product category
    and platform from a JSON rules file.

    Rules file shape: ``{category: {platform: [slot, ...]}, "default": {...}}``
    where a slot looks like ``"09:00 monday"``, ``"10:30 weekdays"``,
    ``"18:00 saturday & sunday"``, or ``"platform default"``.
    """

    # Named day groups a slot may reference ("weekend" is reserved for rules
    # that use it; only "weekdays" is expanded today).
    _DAY_GROUPS = {
        "weekdays": ["monday", "tuesday", "wednesday", "thursday", "friday"],
        "weekend": ["saturday", "sunday"],
    }

    def __init__(self, rules_file, timezone_offset=0):
        # timezone_offset: whole hours added to the parsed local slot time.
        with open(rules_file, "r") as fh:
            self.rules = json.load(fh)
        self.timezone_offset = timezone_offset

    def get_schedule(self, category, platform):
        """Return a "%Y-%m-%d %H:%M" timestamp for a randomly chosen slot,
        preferring category-specific rules over the "default" section.

        Raises ValueError when neither section has rules for the platform.
        """
        category = category.lower()
        platform = platform.lower()
        cat_rules = self.rules.get(category, {})
        default_rules = self.rules.get("default", {})

        if platform in cat_rules:
            slots = cat_rules[platform]
        elif platform in default_rules:
            slots = default_rules[platform]
        else:
            raise ValueError(f"No scheduling rules for {category} or default / {platform}")

        candidates = []
        for raw_slot in slots:
            candidates.extend(self.normalize_slot(raw_slot, platform, default_rules))

        if not candidates:
            # fallback: post tomorrow at 09:00
            fallback = datetime.now().replace(
                hour=9, minute=0, second=0, microsecond=0
            ) + timedelta(days=1)
            return fallback.strftime("%Y-%m-%d %H:%M")

        chosen = random.choice(candidates)
        return self._parse_slot_to_datetime(chosen).strftime("%Y-%m-%d %H:%M")

    def normalize_slot(self, slot: str, platform: str, default_rules: dict) -> list[str]:
        """Expand one rule slot into a list of concrete "HH:MM day" strings."""
        slot = slot.strip().lower()

        if "platform default" in slot:
            # Delegate to the platform's entry in the default section.
            return default_rules.get(platform, []) or []

        if "weekdays" in slot:
            time_part = slot.split()[0]
            return [f"{time_part} {day}" for day in self._DAY_GROUPS["weekdays"]]

        if "&" in slot:
            # e.g. "18:00 saturday & sunday" -> one entry per day
            time_part, day_list = slot.split(" ", 1)
            return [f"{time_part} {d.strip()}" for d in day_list.split("&")]

        return [slot]

    def _parse_slot_to_datetime(self, slot: str) -> datetime:
        """Turn "HH:MM[-HH:MM] [day...]" into the next matching datetime.

        NOTE: only the (start) time is honoured — the day-of-week part is
        currently ignored, so the result is today or tomorrow at that time.
        """
        now = datetime.now()
        time_part = slot.strip().split(" ")[0]

        # "10:00-12:00"-style ranges use the range start.
        if "-" in time_part and ":" in time_part:
            start_time = time_part.split("-")[0]
        else:
            start_time = time_part

        parsed = re.match(r"(\d{1,2}):(\d{2})", start_time)
        if parsed is None:
            raise ValueError(f"Invalid time format in slot: {slot}")

        hour, minute = (int(g) for g in parsed.groups())
        when = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
        when += timedelta(hours=self.timezone_offset)

        # Never schedule in the past.
        if when <= now:
            when += timedelta(days=1)
        return when
740
+
741
class SchedulingState(TypedDict, total=False):
    """State carried through the scheduling stage (all keys optional)."""

    product: Dict[str, Any]  # product record; "Category" key drives rule lookup
    platform: str  # platform key used to pick a scheduling rule
    final_post_struct: Dict[str, Any]  # re-use directly
    scheduled_post: Dict[str, Any]  # final_post_struct plus "scheduled_time"
746
+
747
+
748
+ from typing import Any, Dict, List, TypedDict
749
+
750
class GlobalState(TypedDict, total=False):
    """End-to-end pipeline state: matching -> post generation -> scheduling."""

    # Matching inputs
    product: Dict[str, Any]
    platform: str
    templates: List[Dict[str, Any]]

    # Matching outputs
    ranked_templates: List[Dict[str, Any]]

    # PostGen outputs
    final_post_struct: Dict[str, Any]  # product_id, template_id, post text

    # Scheduling outputs
    scheduled_post: Dict[str, Any]
764
+
765
def matching_node(state: dict) -> dict:
    """Run the template-matching subgraph as one node of the global pipeline."""
    matching_input = {
        "product": state["product"],
        "platform": state["platform"],
        "templates": state["templates"],
        "candidate_templates": [],
        "top_k": 10,
    }
    matching_output = matching_app.invoke(matching_input)
    # Write the ranking back onto the shared state (mutated in place).
    state["ranked_templates"] = matching_output["ranked_templates"]
    return state
776
+
777
def prepare_for_postgen(state: GlobalState) -> PostGenState:
    """Adapt the matching stage's output into the post-generation input shape."""
    ranked = state.get("ranked_templates", [])
    return {
        "product": state["product"],
        "platform": state["platform"],
        "ranked": ranked,
    }
784
+
785
+
786
def postgen_node(state: GlobalState) -> dict:
    """Run the post-generation subgraph as one node of the global pipeline."""
    postgen_input = {
        "product": state["product"],
        "ranked": state["ranked_templates"],
        "platform": state["platform"],
    }
    postgen_output = postgen_app.invoke(postgen_input)
    # Write the generated post back onto the shared state (mutated in place).
    state["final_post_struct"] = postgen_output["final_post_struct"]
    return state
795
+
796
+
797
+
798
def prepare_for_scheduling(state: GlobalState) -> SchedulingState:
    """Project the global state down to the scheduling subgraph's input shape."""
    scheduling_state = {
        "product": state["product"],
        "platform": state["platform"],
        "final_post_struct": state["final_post_struct"],  # no renaming
        "scheduled_post": {},
    }
    return scheduling_state
805
+
806
+
807
def scheduling_node(state: SchedulingState) -> SchedulingState:
    """Attach a scheduled posting time to the generated post.

    Looks up scheduling rules by product category and target platform and
    stores the post, enriched with "scheduled_time", under "scheduled_post".
    """
    product = state["product"]
    platform = state["platform"]
    final_post_struct = state["final_post_struct"]

    # BUG FIX: product.get("Category") returns None when the key is absent,
    # and PostScheduler.get_schedule calls .lower() on it, raising
    # AttributeError. An empty string safely falls through to the scheduler's
    # "default" rules instead.
    category = product.get("Category") or ""

    scheduler = PostScheduler(rules_file=DEFAULT_CONFIG["scheduling"]["rules_file"])
    scheduled_time = scheduler.get_schedule(category, platform)

    state["scheduled_post"] = {
        **final_post_struct,
        "scheduled_time": scheduled_time,
    }
    return state
822
+
823
def build_global_graph():
    """Compile the full pipeline graph: matching -> postgen -> scheduling."""
    graph = StateGraph(GlobalState)

    # Stages in execution order; adapters sit between the subgraphs.
    pipeline = [
        ("matching", matching_node),
        ("prepare_for_postgen", prepare_for_postgen),
        ("postgen", postgen_node),
        ("prepare_for_scheduling", prepare_for_scheduling),
        ("scheduling", scheduling_node),
    ]
    for node_name, node_fn in pipeline:
        graph.add_node(node_name, node_fn)

    # Chain the stages linearly from entry point to END.
    graph.set_entry_point(pipeline[0][0])
    for (src, _), (dst, _) in zip(pipeline, pipeline[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(pipeline[-1][0], END)

    return graph.compile()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langgraph
2
+ langdetect
3
+ python-dotenv
4
+ torch
5
+ transformers
6
+ pyyaml
7
+ typing-extensions
8
+ regex
9
+ streamlit
10
+ accelerate
11
+ sentencepiece
12
+ huggingface-hub