mubashirhussaindev commited on
Commit
61b1afa
·
verified ·
1 Parent(s): 8b07697

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +65 -208
agent.py CHANGED
@@ -1,226 +1,83 @@
1
- from langchain_huggingface import HuggingFacePipeline
2
  from transformers import pipeline
3
  import requests
4
  from bs4 import BeautifulSoup
5
- import os
6
- from typing import List, TypedDict
7
- from dataclasses import dataclass
8
- from langgraph.graph import StateGraph, START, END
9
- from dotenv import load_dotenv
10
  from time import sleep
11
 
12
- load_dotenv()
13
-
14
- @dataclass
15
- class Command:
16
- update: dict = None
17
- goto: str = None
18
-
19
- # Initialize local Hugging Face model with error handling
20
  try:
21
- hf_pipeline = pipeline(
22
- "text2text-generation",
23
- model="google/flan-t5-small",
24
- max_length=512,
25
- temperature=0.7,
26
- )
27
- model = HuggingFacePipeline(pipeline=hf_pipeline)
28
  except Exception as e:
29
- print(f"Error initializing model pipeline: {e}")
30
- model = None # fallback
31
-
32
- def scrape_startpage(query: str, max_results: int = 3) -> List[dict]:
33
- url = f"https://www.startpage.com/sp/search?query={query.replace(' ', '+')}"
34
- headers = {
35
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
36
- }
37
- for attempt in range(3):
38
- try:
39
- response = requests.get(url, headers=headers, timeout=10)
40
- response.raise_for_status()
41
- soup = BeautifulSoup(response.text, "html.parser")
42
- results = []
43
- for result in soup.find_all("div", class_="result")[:max_results]:
44
- title = result.find("h3") or result.find("a")
45
- snippet = result.find("p", class_="desc")
46
- title_text = title.get_text().strip() if title else "No title"
47
- snippet_text = snippet.get_text().strip() if snippet else "No snippet"
48
- results.append({"title": title_text, "snippet": snippet_text})
49
- return results
50
- except Exception as e:
51
- print(f"Error scraping Startpage (attempt {attempt+1}/3): {e}")
52
- sleep(2 ** attempt)
53
- return []
54
 
55
  def safe_invoke(prompt: str) -> str:
56
- if not model:
57
- return "Error: model not initialized"
58
  try:
59
- response = model.invoke(prompt)
60
- if not response or not isinstance(response, str):
61
- return "Error: empty or invalid response"
62
- return response.strip()
63
- except Exception as e:
64
- print(f"Error invoking model: {e}")
65
  return "Error"
66
-
67
- def get_platform_tips(state) -> Command:
68
- query = f"tips on how to write an effective post on {state.get('platform', '')}"
69
- results = scrape_startpage(query)
70
- if results:
71
- prompt = f"Summarize these tips for writing effective posts on {state.get('platform', '')}: {results}"
72
- response = safe_invoke(prompt)
73
- else:
74
- response = f"Use a professional tone, include a clear call-to-action, and keep the post concise for {state.get('platform', '')}."
75
- return Command(update={"tips": response}, goto="web_search")
76
-
77
- def web_search(state) -> Command:
78
- search_results = scrape_startpage(state.get("topic", ""), max_results=3)
79
- return Command(update={"search_results": search_results}, goto="generate_post")
80
-
81
- def generate_social_media_post(state) -> Command:
82
- try:
83
- prompt = f"""
84
- You are a social media strategist for a B2B bank. Generate a {state.get("platform", "")} post.
85
- The post should:
86
- - Be engaging but professional.
87
- - Provide value to corporate clients.
88
- - Focus on {state.get("topic", "")}.
89
- - Incorporate information from {state.get("search_results", "general knowledge about the topic")}
90
- Output as plain text.
91
- """
92
- response = safe_invoke(prompt)
93
- return Command(update={"post": response}, goto="evaluate_engagement")
94
- except Exception as e:
95
- print(f"Error in generate_social_media_post: {e}")
96
- return Command(update={"post": "Error generating post"}, goto="evaluate_engagement")
97
-
98
- def evaluate_engagement(state) -> Command:
99
- prompt = f"""
100
- Score the following post on engagement (1-10) based on the provided social media platform.
101
- Consider clarity, readability, and compelling call-to-action.
102
- Platform: {state.get("platform", "")}
103
- Post: {state.get("post", "")}
104
- Respond with only a number between 1 and 10, no text.
105
- """
106
- score = safe_invoke(prompt)
107
- return Command(update={"engagement_score": score}, goto="evaluate_tone")
108
-
109
- def evaluate_tone(state) -> Command:
110
- prompt = f"""
111
- Score the post’s tone (1-10). Ensure it's:
112
- - Professional but not too rigid.
113
- - Trustworthy and aligned with B2B financial services.
114
- - Aligns with the specified platform.
115
- Platform: {state.get("platform", "")}
116
- Post: {state.get("post", "")}
117
- Respond with only a number between 1 and 10, no text.
118
- """
119
- score = safe_invoke(prompt)
120
- return Command(update={"tone_score": score}, goto="evaluate_clarity")
121
-
122
- def evaluate_clarity(state) -> Command:
123
- prompt = f"""
124
- Score the post on clarity (1-10).
125
- - Avoids jargon.
126
- - Easy to read for busy corporate professionals.
127
- - Appropriate for the social media platform.
128
- Platform: {state.get("platform", "")}
129
- Post: {state.get("post", "")}
130
- Respond with only a number between 1 and 10, no text.
131
- """
132
- score = safe_invoke(prompt)
133
- return Command(update={"clarity_score": score}, goto="revise_if_needed")
134
-
135
- def revise_if_needed(state) -> Command:
136
- try:
137
- scores = [int(state.get("engagement_score", "0")),
138
- int(state.get("tone_score", "0")),
139
- int(state.get("clarity_score", "0"))]
140
  except Exception as e:
141
- print(f"Error parsing scores: {e}")
142
- return Command(update={"post": "Error: Invalid scores"}, goto="get_image")
143
-
144
- avg_score = sum(scores) / len(scores)
145
- if avg_score < 7:
146
- prompt = f"""
147
- Revise this post to improve clarity, engagement, and tone:
148
-
149
- {state.get("post", "")}
150
-
151
- Improve based on the following scores:
152
- Engagement: {state.get("engagement_score", "0")}
153
- Tone: {state.get("tone_score", "0")}
154
- Clarity: {state.get("clarity_score", "0")}
155
- """
156
- revised_post = safe_invoke(prompt)
157
- return Command(update={"post": revised_post}, goto="get_image")
158
-
159
- return Command(goto="get_image")
160
-
161
- def fetch_image(state) -> Command:
162
- pexels_key = os.getenv("PEXELS_API_KEY")
163
- if not pexels_key:
164
- print("PEXELS_API_KEY not set in environment")
165
- return Command(update={"image_url": []}, goto=END)
166
-
167
- prompt = f"""
168
- You are a search optimization assistant. Your task is to take a topic and improve it to ensure the best image results from an image search API like Pexels.
169
- Topic: {state.get('topic', '')}
170
- """
171
- query = safe_invoke(prompt)
172
- if "Error" in query or not query:
173
- query = state.get('topic', '')
174
 
175
- url = "https://api.pexels.com/v1/search"
176
- params = {
177
- "query": query,
178
- "per_page": 5,
179
- "page": 1
180
- }
181
- headers = {
182
- "Authorization": pexels_key
183
- }
184
  for attempt in range(3):
185
  try:
186
- response = requests.get(url, headers=headers, params=params)
187
- response.raise_for_status()
188
- data = response.json()
189
- urls = [photo['url'] for photo in data.get('photos', [])]
190
- return Command(update={"image_url": urls}, goto=END)
 
 
 
 
 
 
 
191
  except Exception as e:
192
- print(f"Error fetching images from Pexels (attempt {attempt+1}/3): {e}")
193
  sleep(2 ** attempt)
194
- return Command(update={"image_url": []}, goto=END)
195
-
196
- class State(TypedDict, total=False):
197
- topic: str
198
- platform: str
199
- tips: str
200
- search_results: List[dict]
201
- post: str
202
- engagement_score: str
203
- tone_score: str
204
- clarity_score: str
205
- image_url: List[str]
206
-
207
- workflow = StateGraph(State)
208
- workflow.add_node("get_tips", get_platform_tips)
209
- workflow.add_node("web_search", web_search)
210
- workflow.add_node("generate_post", generate_social_media_post)
211
- workflow.add_node("evaluate_engagement", evaluate_engagement)
212
- workflow.add_node("evaluate_tone", evaluate_tone)
213
- workflow.add_node("evaluate_clarity", evaluate_clarity)
214
- workflow.add_node("revise_if_needed", revise_if_needed)
215
- workflow.add_node("get_image", fetch_image)
216
-
217
- workflow.add_edge(START, "get_tips")
218
- workflow.add_edge("get_tips", "web_search")
219
- workflow.add_edge("web_search", "generate_post")
220
- workflow.add_edge("generate_post", "evaluate_engagement")
221
- workflow.add_edge("evaluate_engagement", "evaluate_tone")
222
- workflow.add_edge("evaluate_tone", "evaluate_clarity")
223
- workflow.add_edge("evaluate_clarity", "revise_if_needed")
224
- workflow.add_edge("revise_if_needed", "get_image")
225
 
226
- graph = workflow.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline
2
  import requests
3
  from bs4 import BeautifulSoup
 
 
 
 
 
4
  from time import sleep
5
 
6
+ # Load the model once
 
 
 
 
 
 
 
7
  try:
8
+ hf_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_length=512, temperature=0.7)
 
 
 
 
 
 
9
  except Exception as e:
10
+ print(f"Error loading model: {e}")
11
+ hf_pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def safe_invoke(prompt: str) -> str:
14
+ if not hf_pipe:
15
+ return "Error: Model not loaded"
16
  try:
17
+ outputs = hf_pipe(prompt)
18
+ if outputs and isinstance(outputs, list):
19
+ return outputs[0]['generated_text'].strip()
 
 
 
20
  return "Error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  except Exception as e:
22
+ print(f"Error during generation: {e}")
23
+ return "Error"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ def scrape_startpage(query: str, max_results: int = 3):
26
+ url = f"https://www.startpage.com/sp/search?query={query.replace(' ', '+')}"
27
+ headers = {"User-Agent": "Mozilla/5.0"}
 
 
 
 
 
 
28
  for attempt in range(3):
29
  try:
30
+ res = requests.get(url, headers=headers, timeout=10)
31
+ res.raise_for_status()
32
+ soup = BeautifulSoup(res.text, "html.parser")
33
+ results = []
34
+ # Find divs with class "result" (Startpage search results)
35
+ for r in soup.find_all("div", class_="result")[:max_results]:
36
+ title = r.find("h3")
37
+ desc = r.find("p", class_="desc")
38
+ title_text = title.get_text(strip=True) if title else "No title"
39
+ desc_text = desc.get_text(strip=True) if desc else "No description"
40
+ results.append(f"{title_text}: {desc_text}")
41
+ return results
42
  except Exception as e:
43
+ print(f"Scrape error (attempt {attempt+1}): {e}")
44
  sleep(2 ** attempt)
45
+ return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ def generate_post(topic, platform, search_results):
48
+ base_prompt = f"""You are a social media expert. Write a professional {platform} post about "{topic}".
49
+ Use this information to help you: {search_results}
50
+ Make the post clear, engaging, and suitable for corporate clients.
51
+ Output only the post text."""
52
+ return safe_invoke(base_prompt)
53
+
54
+ def score_post(post, platform, score_type):
55
+ prompt = f"""Rate the following post on {score_type} from 1 to 10 (just give a number):
56
+ Platform: {platform}
57
+ Post: {post}
58
+ """
59
+ return safe_invoke(prompt)
60
+
61
+ def workflow(topic, platform):
62
+ # Step 1: Web search results
63
+ search_results = scrape_startpage(topic)
64
+ combined_results = " | ".join(search_results) if search_results else "No additional info."
65
+
66
+ # Step 2: Generate post
67
+ post = generate_post(topic, platform, combined_results)
68
+ if post == "Error":
69
+ return post, "Error", "Error", "Error"
70
+
71
+ # Step 3: Scores
72
+ engagement = score_post(post, platform, "engagement")
73
+ tone = score_post(post, platform, "tone")
74
+ clarity = score_post(post, platform, "clarity")
75
+
76
+ # Validate scores (should be digits)
77
+ def valid_score(s):
78
+ return s and s.strip().isdigit()
79
+
80
+ if not all(map(valid_score, (engagement, tone, clarity))):
81
+ return post, "Error", "Error", "Error"
82
+
83
+ return post, engagement, tone, clarity