Alpha108 commited on
Commit
33e6a1b
·
verified ·
1 Parent(s): 4b9478f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -8
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import streamlit as st
3
  from llm_groq import generate_post, transform_post, generate_hooks, DEFAULT_MODEL
4
  from prompts import build_quick_prompt, build_post_prompt, transform_instruction
@@ -8,7 +9,7 @@ from ui_components import quick_controls, pro_controls
8
  st.set_page_config(page_title="LinkedIn Post Generator — Groq", layout="centered")
9
  st.title("LinkedIn Post Generator — Quick & Pro (Groq)")
10
 
11
- # Sidebar (unique keys not required here but safe to add if duplicated later)
12
  with st.sidebar:
13
  st.subheader("Groq & Decoding")
14
  model = st.selectbox("Model", [DEFAULT_MODEL, "llama-3.1-8b-instant", "mixtral-8x7b-32768"], index=0, key="sb_model")
@@ -21,6 +22,21 @@ tabs = st.tabs(["Quick Draft", "Pro Mode", "History"])
21
  if "history" not in st.session_state:
22
  st.session_state.history = []
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Quick Draft
25
  with tabs[0]:
26
  idea, tone, words, variations, include_emoji, add_hashtags, language = quick_controls()
@@ -38,7 +54,18 @@ with tabs[0]:
38
  for _ in range(variations):
39
  raw = generate_post(prompt, model, temperature, top_p, max_tokens)
40
  clean = dedupe_sentences(strip_labels(raw))
41
- posts.append(clean)
 
 
 
 
 
 
 
 
 
 
 
42
  except Exception as e:
43
  st.error(f"Generation failed: {e}")
44
  posts = []
@@ -94,7 +121,7 @@ with tabs[1]:
94
 
95
  # Refinements
96
  col1,col2,col3,col4,col5 = st.columns(5)
97
- def refine(kind, key_suffix):
98
  if st.session_state.get("history") and st.session_state.history[-1].get("post"):
99
  try:
100
  instr = transform_instruction(kind)
@@ -103,11 +130,11 @@ with tabs[1]:
103
  except Exception as e:
104
  st.error(f"Refinement failed: {e}")
105
 
106
- if col1.button("Shorter", key="pro_shorter"): refine("shorter","shorter")
107
- if col2.button("Punchier hook", key="pro_punchy"): refine("punchier","punchy")
108
- if col3.button("Add data point", key="pro_adddata"): refine("add_data","adddata")
109
- if col4.button("No emojis", key="pro_noemoji"): refine("less_emoji","noemoji")
110
- if col5.button("Add hashtags", key="pro_addtags"): refine("add_tags","addtags")
111
 
112
  if st.session_state.get("history") and st.session_state.history[-1].get("post"):
113
  st.write(st.session_state.history[-1]["post"])
 
1
  import os
2
+ import re
3
  import streamlit as st
4
  from llm_groq import generate_post, transform_post, generate_hooks, DEFAULT_MODEL
5
  from prompts import build_quick_prompt, build_post_prompt, transform_instruction
 
9
  st.set_page_config(page_title="LinkedIn Post Generator — Groq", layout="centered")
10
  st.title("LinkedIn Post Generator — Quick & Pro (Groq)")
11
 
12
+ # Sidebar
13
  with st.sidebar:
14
  st.subheader("Groq & Decoding")
15
  model = st.selectbox("Model", [DEFAULT_MODEL, "llama-3.1-8b-instant", "mixtral-8x7b-32768"], index=0, key="sb_model")
 
22
  if "history" not in st.session_state:
23
  st.session_state.history = []
24
 
25
+ def quick_quality_fix(text, want_hashtags=True, allow_emoji=True):
26
+ lines = [l for l in text.strip().splitlines() if l.strip()]
27
+ if len(lines) < 4 or len(lines) > 7:
28
+ return None
29
+ if not allow_emoji:
30
+ text = re.sub(r"[^\w\s#.,:;%&()\-\+\[\]{}'\"/]", "", text)
31
+ tags = re.findall(r"#\w+", text)
32
+ if not want_hashtags and tags:
33
+ for t in tags:
34
+ text = text.replace(t, "")
35
+ if want_hashtags and len(tags) > 2:
36
+ for t in tags[2:]:
37
+ text = text.replace(t, "")
38
+ return text.strip()
39
+
40
  # Quick Draft
41
  with tabs[0]:
42
  idea, tone, words, variations, include_emoji, add_hashtags, language = quick_controls()
 
54
  for _ in range(variations):
55
  raw = generate_post(prompt, model, temperature, top_p, max_tokens)
56
  clean = dedupe_sentences(strip_labels(raw))
57
+ fixed = quick_quality_fix(clean, want_hashtags=add_hashtags, allow_emoji=include_emoji)
58
+ if fixed is None:
59
+ corrective = (
60
+ prompt + "\n\nRegenerate as 4–6 short lines, each under 18 words, "
61
+ "include one concrete metric or date, "
62
+ f"{'max 1 emoji' if include_emoji else 'no emojis'}, "
63
+ f"{'1–2 niche hashtags at the end' if add_hashtags else 'no hashtags'}."
64
+ )
65
+ raw2 = generate_post(corrective, model, temperature, top_p, max_tokens)
66
+ clean2 = dedupe_sentences(strip_labels(raw2))
67
+ fixed = quick_quality_fix(clean2, want_hashtags=add_hashtags, allow_emoji=include_emoji) or clean2
68
+ posts.append(fixed)
69
  except Exception as e:
70
  st.error(f"Generation failed: {e}")
71
  posts = []
 
121
 
122
  # Refinements
123
  col1,col2,col3,col4,col5 = st.columns(5)
124
+ def refine(kind):
125
  if st.session_state.get("history") and st.session_state.history[-1].get("post"):
126
  try:
127
  instr = transform_instruction(kind)
 
130
  except Exception as e:
131
  st.error(f"Refinement failed: {e}")
132
 
133
+ if col1.button("Shorter", key="pro_shorter"): refine("shorter")
134
+ if col2.button("Punchier hook", key="pro_punchy"): refine("punchier")
135
+ if col3.button("Add data point", key="pro_adddata"): refine("add_data")
136
+ if col4.button("No emojis", key="pro_noemoji"): refine("less_emoji")
137
+ if col5.button("Add hashtags", key="pro_addtags"): refine("add_tags")
138
 
139
  if st.session_state.get("history") and st.session_state.history[-1].get("post"):
140
  st.write(st.session_state.history[-1]["post"])