# File size: 2,783 Bytes
# 28f56d3
# (line-number gutter 1-92 omitted; presumably residue from a web file viewer)
from langchain_groq import ChatGroq
import os
from dotenv import load_dotenv
import pandas as pd
import json
class FewShotPosts:
    """In-memory store of example posts used for few-shot prompting.

    Posts are read from a JSON file into a pandas DataFrame. The class reads
    the ``line_count``, ``tags`` and ``language`` fields of every post and
    adds a derived ``length`` column ("Short", "Medium" or "Long").
    """

    def __init__(self, file_path="processed_posts.json"):
        """Create the store and immediately load the posts at *file_path*."""
        self.df = None  # DataFrame of posts, set by load_posts()
        self.unique_tags = None  # sorted list of distinct tags, set by load_posts()
        self.load_posts(file_path)

    def load_posts(self, file_path):
        """Read posts from a JSON file and rebuild ``df`` and ``unique_tags``.

        Args:
            file_path: Path of a UTF-8 encoded JSON file containing the posts.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the posts have no ``line_count`` or ``tags`` field.
        """
        with open(file_path, encoding="utf-8") as f:
            posts = json.load(f)

        self.df = pd.json_normalize(posts)
        self.df['length'] = self.df['line_count'].apply(self.categorize_length)

        # Flatten the per-post tag lists in a single pass. Summing the column
        # would concatenate the lists pairwise (quadratic), and sorting gives
        # callers a stable order instead of arbitrary set ordering.
        self.unique_tags = sorted(
            {tag for tags in self.df['tags'] for tag in tags}
        )

    def get_filtered_posts(self, length, language, tag):
        """Return the posts matching all three criteria.

        Args:
            length: Length category to match against the ``length`` column.
            language: Value to match against the ``language`` column.
            tag: Tag that must be present in the post's ``tags``.

        Returns:
            A list of dicts, one per matching post (empty if none match).
        """
        has_tag = self.df['tags'].apply(lambda tags: tag in tags)
        same_language = self.df['language'] == language
        same_length = self.df['length'] == length
        matches = self.df[has_tag & same_language & same_length]
        return matches.to_dict(orient='records')

    def categorize_length(self, line_count):
        """Map a line count to "Short" (< 5), "Medium" (5-10) or "Long" (> 10).

        NOTE(review): a 5-line post is "Medium" here, while get_length_str()
        describes "Short" as "1 to 5 lines" - confirm which boundary is
        intended.
        """
        if line_count < 5:
            return "Short"
        if line_count <= 10:
            return "Medium"
        return "Long"

    def get_tags(self):
        """Return the list of distinct tags found in the loaded posts."""
        return self.unique_tags
# Module-level setup, executed on import.
# load_dotenv() runs first so GROQ_API_KEY can be supplied through a .env file
# before it is read from the environment below.
load_dotenv()

# Shared chat-model client used by generate_post().
llm = ChatGroq(
    model_name="llama-3.3-70b-versatile",
    groq_api_key=os.environ.get("GROQ_API_KEY"),
)

# Shared example-post store; loads its default JSON file on construction.
few_shot = FewShotPosts()
def get_length_str(length):
    """Translate a length label into the line-range wording used in prompts.

    Args:
        length: One of "Short", "Medium" or "Long".

    Returns:
        The matching description (e.g. "6 to 10 lines"), or ``None`` when
        *length* is not one of the known labels.
    """
    descriptions = (
        ("Short", "1 to 5 lines"),
        ("Medium", "6 to 10 lines"),
        ("Long", "11 to 15 lines"),
    )
    for label, description in descriptions:
        if length == label:
            return description
    return None
def generate_post(length, language, tag):
    """Generate a post by sending the assembled prompt to the shared LLM.

    Args:
        length: Length label forwarded to ``get_prompt``.
        language: Language name forwarded to ``get_prompt``.
        tag: Topic tag forwarded to ``get_prompt``.

    Returns:
        The ``content`` attribute of the response from ``llm.invoke``.
    """
    return llm.invoke(get_prompt(length, language, tag)).content
def get_prompt(length, language, tag, max_examples=2):
    """Build the LLM prompt for a post, with few-shot examples when available.

    Args:
        length: Length label; converted to wording by ``get_length_str`` and
            also used to select example posts.
        language: Language named in the prompt and used to select examples.
        tag: Topic named in the prompt and used to select examples.
        max_examples: Upper bound on the number of example posts appended.
            The default of 2 matches the previously hard-coded limit; values
            below 1 append no examples.

    Returns:
        The complete prompt string.
    """
    length_str = get_length_str(length)

    # Instructions are spelled out line by line so the exact whitespace sent
    # to the model does not depend on how this function is indented.
    parts = [
        "\n"
        "Generate a LinkedIn post using the below information. No preamble.\n"
        f"1) Topic: {tag}\n"
        f"2) Length: {length_str}\n"
        f"3) Language: {language}\n"
        "If Language is Hinglish then it means it is a mix of Hindi and English.\n"
        "The script for the generated post should always be English.\n"
    ]

    # Cap the examples up front; the header is only added when at least one
    # example will actually follow it.
    limit = max(max_examples, 0)
    examples = few_shot.get_filtered_posts(length, language, tag)[:limit]
    if examples:
        parts.append("4) Use the writing style as per the following examples.")
        parts.extend(
            f"\n\n Example {number}: \n\n {post['text']}"
            for number, post in enumerate(examples, start=1)
        )

    return "".join(parts)
if __name__ == "__main__":
    # Manual smoke test: generates one post through generate_post(), which
    # invokes the configured LLM, and prints the result.
    print(generate_post("Medium", "English", "Mental Health"))