prasannahf commited on
Commit
e1c8621
·
verified ·
1 Parent(s): 4b0ede1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -97
app.py CHANGED
@@ -5,22 +5,23 @@ import torch
5
  from langgraph.graph import StateGraph, START, END
6
  from langchain.schema import HumanMessage
7
  from langchain_groq import ChatGroq
8
- from langsmith import traceable # βœ… Added LangSmith for Debugging
9
  from typing import TypedDict
10
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
11
 
12
- # βœ… Load API keys from Hugging Face Secrets
13
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
14
  LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
15
 
16
- # βœ… Set LangSmith Debugging
 
17
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
18
  os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API_KEY
19
 
20
- # βœ… Initialize Groq LLM (for content generation)
21
  llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="mixtral-8x7b-32768")
22
 
23
- # βœ… Define State for LangGraph
24
  class State(TypedDict):
25
  topic: str
26
  titles: list
@@ -31,19 +32,19 @@ class State(TypedDict):
31
  tone: str
32
  language: str
33
 
34
- # βœ… Function to generate multiple blog titles using Groq
35
- @traceable(name="Generate Titles") # βœ… Debugging with LangSmith
36
  def generate_titles(data):
37
  topic = data.get("topic", "")
38
  prompt = f"Generate three short and catchy blog titles for the topic: {topic}. Each title should be under 10 words. Separate them with new lines."
39
 
40
  response = llm([HumanMessage(content=prompt)])
41
- titles = response.content.strip().split("\n")
42
 
43
- return {"titles": titles, "selected_title": titles[0]}
44
 
45
- # βœ… Function to generate blog content with tone using Groq
46
- @traceable(name="Generate Content") # βœ… Debugging with LangSmith
47
  def generate_content(data):
48
  title = data.get("selected_title", "")
49
  tone = data.get("tone", "Neutral")
@@ -52,8 +53,8 @@ def generate_content(data):
52
  response = llm([HumanMessage(content=prompt)])
53
  return {"content": response.content.strip()}
54
 
55
- # βœ… Function to generate summary using Groq
56
- @traceable(name="Generate Summary") # βœ… Debugging with LangSmith
57
  def generate_summary(data):
58
  content = data.get("content", "")
59
  prompt = f"Summarize this blog post in a short and engaging way: {content}"
@@ -61,16 +62,12 @@ def generate_summary(data):
61
  response = llm([HumanMessage(content=prompt)])
62
  return {"summary": response.content.strip()}
63
 
64
- # βœ… Load translation model (NLLB-200)
65
- def load_translation_model():
66
- model_name = "facebook/nllb-200-distilled-600M"
67
- tokenizer = AutoTokenizer.from_pretrained(model_name)
68
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
69
- return tokenizer, model
70
 
71
- tokenizer, model = load_translation_model()
72
-
73
- # βœ… Language codes for NLLB-200
74
  language_codes = {
75
  "English": "eng_Latn",
76
  "Hindi": "hin_Deva",
@@ -79,109 +76,74 @@ language_codes = {
79
  "French": "fra_Latn"
80
  }
81
 
82
- # βœ… Function to translate blog content using NLLB-200
83
- @traceable(name="Translate Content") # βœ… Debugging with LangSmith
84
  def translate_content(data):
85
  content = data.get("content", "")
86
  language = data.get("language", "English")
87
 
88
  if language == "English":
89
- return {"translated_content": content}
90
-
91
- tgt_lang = language_codes.get(language, "eng_Latn")
92
-
93
- # βœ… Split content into smaller chunks (Avoids token limit issues)
94
- max_length = 512
95
- sentences = content.split(". ")
96
- chunks = []
97
- current_chunk = ""
98
-
99
- for sentence in sentences:
100
- if len(current_chunk) + len(sentence) < max_length:
101
- current_chunk += sentence + ". "
102
- else:
103
- chunks.append(current_chunk.strip())
104
- current_chunk = sentence + ". "
105
 
106
- if current_chunk:
107
- chunks.append(current_chunk.strip())
 
 
108
 
109
- # βœ… Translate each chunk separately and combine results
110
- translated_chunks = []
111
- for chunk in chunks:
112
- inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
113
- translated_tokens = model.generate(**inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang))
114
- translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
115
- translated_chunks.append(translated_text.strip())
116
 
117
- full_translation = " ".join(translated_chunks)
118
-
119
- return {"translated_content": full_translation}
120
-
121
- # βœ… Create LangGraph Workflow
122
  def make_blog_generation_graph():
123
- """Create a LangGraph workflow for Blog Generation"""
124
  graph_workflow = StateGraph(State)
125
-
126
- # Define Nodes
127
  graph_workflow.add_node("title_generation", generate_titles)
128
  graph_workflow.add_node("content_generation", generate_content)
129
  graph_workflow.add_node("summary_generation", generate_summary)
130
- graph_workflow.add_node("translation", translate_content)
131
-
132
- # Define Execution Order
133
  graph_workflow.add_edge(START, "title_generation")
134
  graph_workflow.add_edge("title_generation", "content_generation")
135
  graph_workflow.add_edge("content_generation", "summary_generation")
136
  graph_workflow.add_edge("content_generation", "translation")
137
  graph_workflow.add_edge("summary_generation", END)
138
  graph_workflow.add_edge("translation", END)
139
-
140
  return graph_workflow.compile()
141
 
142
- # βœ… Function to generate blog content (Fixed)
143
- def generate_blog(topic, tone, language):
144
  try:
145
  if not topic:
146
  return "⚠️ Please enter a topic.", "", "", "", ""
147
-
148
  blog_agent = make_blog_generation_graph()
149
- result = blog_agent.invoke({"topic": topic, "tone": tone, "language": language})
150
-
151
- return result["titles"], result["selected_title"], result["content"], result["summary"], result["translated_content"]
152
 
153
  except Exception as e:
154
  error_message = f"⚠️ Error: {str(e)}\n{traceback.format_exc()}"
155
  return error_message, "", "", "", ""
156
 
157
- # βœ… Gradio UI
158
  with gr.Blocks() as app:
159
- gr.Markdown(
160
- """
161
- ### 🌍 Why Translate?
162
- - πŸ—£οΈ **Multilingual Support**
163
- - 🌎 **Expand Reach**
164
- - βœ… **Better Understanding**
165
- - πŸ€– **AI-Powered Accuracy**
166
- """
167
- )
168
-
169
- gr.Interface(
170
- fn=generate_blog,
171
- inputs=[
172
- gr.Textbox(label="Enter a topic for your blog"),
173
- gr.Dropdown(["Neutral", "Formal", "Casual", "Persuasive", "Humorous"], label="Select Blog Tone", value="Neutral"),
174
- gr.Dropdown(["English", "Hindi", "Telugu", "Spanish", "French"], label="Translate Blog To", value="English"),
175
- ],
176
- outputs=[
177
- gr.Textbox(label="Suggested Blog Titles"),
178
- gr.Textbox(label="Selected Blog Title"),
179
- gr.Textbox(label="Generated Blog Content"),
180
- gr.Textbox(label="Blog Summary"),
181
- gr.Textbox(label="Translated Blog Content"),
182
- ],
183
- title="πŸš€ AI-Powered Blog Generator",
184
- )
185
-
186
- # βœ… Launch the Gradio App
187
  app.launch(share=True)
 
5
  from langgraph.graph import StateGraph, START, END
6
  from langchain.schema import HumanMessage
7
  from langchain_groq import ChatGroq
8
+ from langsmith import traceable
9
  from typing import TypedDict
10
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
11
 
12
+ # Load API Keys
13
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
14
  LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
15
 
16
+ # Set environment variables
17
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
18
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
19
  os.environ["LANGCHAIN_API_KEY"] = LANGSMITH_API_KEY
20
 
21
+ # Initialize Groq LLM
22
  llm = ChatGroq(groq_api_key=GROQ_API_KEY, model_name="mixtral-8x7b-32768")
23
 
24
+ # Define State for LangGraph
25
  class State(TypedDict):
26
  topic: str
27
  titles: list
 
32
  tone: str
33
  language: str
34
 
35
+ # Function to generate multiple blog titles
36
+ @traceable(name="Generate Titles")
37
  def generate_titles(data):
38
  topic = data.get("topic", "")
39
  prompt = f"Generate three short and catchy blog titles for the topic: {topic}. Each title should be under 10 words. Separate them with new lines."
40
 
41
  response = llm([HumanMessage(content=prompt)])
42
+ titles = response.content.strip().split("\n") # Get three titles as a list
43
 
44
+ return {"titles": titles} # No default selection
45
 
46
+ # Function to generate blog content with user-selected title
47
+ @traceable(name="Generate Content")
48
  def generate_content(data):
49
  title = data.get("selected_title", "")
50
  tone = data.get("tone", "Neutral")
 
53
  response = llm([HumanMessage(content=prompt)])
54
  return {"content": response.content.strip()}
55
 
56
+ # Function to generate summary
57
+ @traceable(name="Generate Summary")
58
  def generate_summary(data):
59
  content = data.get("content", "")
60
  prompt = f"Summarize this blog post in a short and engaging way: {content}"
 
62
  response = llm([HumanMessage(content=prompt)])
63
  return {"summary": response.content.strip()}
64
 
65
+ # Load translation model
66
+ model_name = "facebook/nllb-200-distilled-600M"
67
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
68
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
 
69
 
70
+ # Language codes
 
 
71
  language_codes = {
72
  "English": "eng_Latn",
73
  "Hindi": "hin_Deva",
 
76
  "French": "fra_Latn"
77
  }
78
 
79
+ # Function to translate blog content
80
+ @traceable(name="Translate Content")
81
  def translate_content(data):
82
  content = data.get("content", "")
83
  language = data.get("language", "English")
84
 
85
  if language == "English":
86
+ return {"translated_content": content} # No translation needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ tgt_lang = language_codes.get(language, "eng_Latn")
89
+ inputs = tokenizer(content, return_tensors="pt", padding=True, truncation=True)
90
+ translated_tokens = model.generate(**inputs, forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_lang))
91
+ translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
92
 
93
+ return {"translated_content": translated_text.strip()}
 
 
 
 
 
 
94
 
95
+ # Create LangGraph Workflow
 
 
 
 
96
  def make_blog_generation_graph():
 
97
  graph_workflow = StateGraph(State)
98
+
 
99
  graph_workflow.add_node("title_generation", generate_titles)
100
  graph_workflow.add_node("content_generation", generate_content)
101
  graph_workflow.add_node("summary_generation", generate_summary)
102
+ graph_workflow.add_node("translation", translate_content)
103
+
 
104
  graph_workflow.add_edge(START, "title_generation")
105
  graph_workflow.add_edge("title_generation", "content_generation")
106
  graph_workflow.add_edge("content_generation", "summary_generation")
107
  graph_workflow.add_edge("content_generation", "translation")
108
  graph_workflow.add_edge("summary_generation", END)
109
  graph_workflow.add_edge("translation", END)
110
+
111
  return graph_workflow.compile()
112
 
113
+ # Gradio Interface
114
+ def generate_blog(topic, tone, language, selected_title):
115
  try:
116
  if not topic:
117
  return "⚠️ Please enter a topic.", "", "", "", ""
118
+
119
  blog_agent = make_blog_generation_graph()
120
+ result = blog_agent.invoke({"topic": topic, "tone": tone, "language": language, "selected_title": selected_title})
121
+
122
+ return result["titles"], selected_title, result["content"], result["summary"], result["translated_content"]
123
 
124
  except Exception as e:
125
  error_message = f"⚠️ Error: {str(e)}\n{traceback.format_exc()}"
126
  return error_message, "", "", "", ""
127
 
 
128
  with gr.Blocks() as app:
129
+ gr.Markdown("""
130
+ ### 🌍 Why Translate?
131
+ We provide translation to make the blog content **accessible to a global audience**.
132
+ - πŸ—£οΈ **Multilingual Support** – Read blogs in your preferred language.
133
+ - 🌎 **Expand Reach** – Reach international readers.
134
+ - βœ… **Better Understanding** – Enjoy content in a language you're comfortable with.
135
+ - πŸ€– **AI-Powered Accuracy** – Uses advanced AI models for precise translation.
136
+ """)
137
+
138
+ topic_input = gr.Textbox(label="Enter a topic")
139
+ tone_input = gr.Dropdown(["Neutral", "Formal", "Casual"], label="Select Blog Tone")
140
+ language_input = gr.Dropdown(["English", "Hindi", "Telugu", "Spanish", "French"], label="Translate Blog To")
141
+ title_output = gr.Dropdown(label="Select Blog Title")
142
+ content_output = gr.Textbox(label="Generated Blog Content")
143
+ summary_output = gr.Textbox(label="Blog Summary")
144
+ translation_output = gr.Textbox(label="Translated Blog Content")
145
+
146
+ generate_button = gr.Button("Generate Blog")
147
+ generate_button.click(generate_blog, [topic_input, tone_input, language_input, title_output], [title_output, content_output, summary_output, translation_output])
148
+
 
 
 
 
 
 
 
 
149
  app.launch(share=True)