GLAkavya commited on
Commit
612746c
·
verified ·
1 Parent(s): c8b991f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -110
app.py CHANGED
@@ -1,127 +1,249 @@
1
  import os
2
  import random
3
- import streamlit as st
4
- import pandas as pd
5
- import plotly.express as px
 
 
 
6
  import google.generativeai as genai
7
  from transformers import pipeline
8
 
9
- # ========== API KEYS ==========
10
- genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
11
- sentiment_pipeline = pipeline("sentiment-analysis")
 
 
12
 
13
- # ========== POST GENERATOR ==========
14
- def generate_posts(hashtag, n=20):
15
- """Try Gemini first, fallback to HuggingFace generated posts"""
16
- posts = []
17
- source = "Gemini"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
 
 
 
 
19
  try:
20
- prompt = f"Generate {n} realistic, diverse, short social media posts about {hashtag}. Use emojis."
21
- response = genai.GenerativeModel("gemini-1.5-flash").generate_content(prompt)
22
- text = response.text.strip().split("\n")
23
- posts = [t for t in text if len(t.strip()) > 0][:n]
24
  except Exception:
25
- source = "HuggingFace"
26
- base_posts = [
27
- f"{hashtag} is the best thing ever 🎉",
28
- f"I'm disappointed with {hashtag} 💔",
29
- f"Not sure how I feel about {hashtag} 🤔",
30
- f"Super excited about {hashtag} 🔥",
31
- f"People are talking about {hashtag} everywhere 🌍",
32
- f"{hashtag} totally failed expectations 😠",
33
- f"I love {hashtag}! It's amazing ❤️",
34
- ]
35
- posts = random.choices(base_posts, k=n)
36
-
37
- return posts, source
38
-
39
- # ========== SENTIMENT ANALYSIS ==========
40
- def analyze_posts(posts):
41
- results = sentiment_pipeline(posts)
42
- data = []
43
- for post, res in zip(posts, results):
44
- data.append({
45
- "Post": post,
46
- "Sentiment": res["label"],
47
- "Confidence": round(res["score"], 2)
48
- })
49
- return pd.DataFrame(data)
50
-
51
- # ========== STREAMLIT UI ==========
52
- st.set_page_config(
53
- page_title="AI Sentiment Analyzer",
54
- page_icon="📊",
55
- layout="wide",
56
- initial_sidebar_state="expanded"
57
- )
58
 
59
- st.markdown(
60
- """
61
- <style>
62
- body {
63
- background: linear-gradient(135deg, #0f2027, #203a43, #2c5364);
64
- color: white;
65
- }
66
- .stButton button {
67
- background: linear-gradient(45deg, #ff6b6b, #f7b733);
68
- color: white;
69
- font-weight: bold;
70
- border-radius: 12px;
71
- padding: 10px 20px;
72
- transition: 0.3s;
73
- }
74
- .stButton button:hover {
75
- transform: scale(1.05);
76
- background: linear-gradient(45deg, #36d1dc, #5b86e5);
77
- }
78
- .moving {
79
- animation: float 3s ease-in-out infinite;
80
- }
81
- @keyframes float {
82
- 0% { transform: translatey(0px); }
83
- 50% { transform: translatey(-10px); }
84
- 100% { transform: translatey(0px); }
 
 
 
 
 
 
 
85
  }
86
- </style>
87
- """,
88
- unsafe_allow_html=True
89
- )
90
 
91
- st.title("🚀 Social Media Sentiment Analyzer")
92
- st.markdown("### Stream posts Analyze moods • Visualize trends with animations")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
- # Sidebar
95
- st.sidebar.header("⚙️ Controls")
96
- hashtag = st.sidebar.text_input("Enter Hashtag", "#gla")
97
- num_posts = st.sidebar.slider("Number of Posts", 5, 50, 20)
98
- vis_type = st.sidebar.selectbox("Choose Visualization", ["Bar", "Pie", "Line"])
 
 
 
 
 
 
 
99
 
100
- if st.sidebar.button("🔍 Run Analysis"):
101
- with st.spinner("Generating posts and analyzing sentiments..."):
102
- posts, source = generate_posts(hashtag, num_posts)
103
- df = analyze_posts(posts)
 
 
 
 
 
 
104
 
105
- st.success(f"✅ Posts generated by **{source}** | Total: {len(posts)}")
 
106
 
107
- st.dataframe(df, use_container_width=True)
 
 
 
 
 
 
 
 
 
 
108
 
109
- # ===== VISUALIZATIONS =====
110
- if vis_type == "Bar":
111
- fig = px.bar(df, x="Sentiment", color="Sentiment", title=f"Sentiment Distribution for {hashtag}", animation_frame=None)
112
- elif vis_type == "Pie":
113
- fig = px.pie(df, names="Sentiment", title=f"Sentiment Share for {hashtag}", hole=0.4)
114
- else:
115
- fig = px.line(df, y="Confidence", title=f"Confidence Trend for {hashtag}")
116
-
117
- st.plotly_chart(fig, use_container_width=True)
118
-
119
- # ===== Moving 3D-like floating animation =====
120
- st.markdown(
121
- f"""
122
- <div class="moving" style="font-size:24px; text-align:center; margin-top:30px;">
123
- 🌍 🔥 ❤️ 🎉 Trending Vibes around <b>{hashtag}</b>
124
- </div>
125
- """,
126
- unsafe_allow_html=True
127
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import random
3
+ import time
4
+ from typing import List, Dict
5
+
6
+ from flask import Flask, jsonify, request, render_template
7
+ from flask_cors import CORS
8
+
9
  import google.generativeai as genai
10
  from transformers import pipeline
11
 
12
+ # -----------------------
13
+ # Flask setup
14
+ # -----------------------
15
+ app = Flask(__name__, static_folder="static", template_folder="templates")
16
+ CORS(app)
17
 
18
+ # -----------------------
19
+ # Config & Environment
20
+ # -----------------------
21
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
22
+ if GOOGLE_API_KEY:
23
+ genai.configure(api_key=GOOGLE_API_KEY)
24
+
25
+ # Cap posts
26
+ MAX_POSTS = 50
27
+ DEFAULT_POSTS = 20
28
+
29
+ # -----------------------
30
+ # Sentiment Analyzer (HF)
31
+ # -----------------------
32
+ # Pin a specific model for stability (avoid the production warning)
33
+ SENTIMENT_MODEL = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
34
+ sentiment_analyzer = pipeline(
35
+ "sentiment-analysis",
36
+ model=SENTIMENT_MODEL,
37
+ device=-1 # CPU
38
+ )
39
 
40
+ # -----------------------
41
+ # Helpers
42
+ # -----------------------
43
+ def normalize_count(n: int) -> int:
44
  try:
45
+ n = int(n)
 
 
 
46
  except Exception:
47
+ n = DEFAULT_POSTS
48
+ n = max(1, min(MAX_POSTS, n))
49
+ return n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ def parse_sentiment(label: str, score: float) -> Dict[str, str]:
52
+ # Standardize to POSITIVE / NEGATIVE / NEUTRAL (distilbert gives POSITIVE/NEGATIVE)
53
+ if label.upper() == "POSITIVE":
54
+ sentiment = "POSITIVE"
55
+ elif label.upper() == "NEGATIVE":
56
+ sentiment = "NEGATIVE"
57
+ else:
58
+ sentiment = "NEUTRAL"
59
+ return {"sentiment": sentiment, "score": float(score)}
60
+
61
+ def compute_aggregate(rows: List[Dict]) -> Dict:
62
+ pos = sum(1 for r in rows if r["sentiment"] == "POSITIVE")
63
+ neg = sum(1 for r in rows if r["sentiment"] == "NEGATIVE")
64
+ neu = sum(1 for r in rows if r["sentiment"] == "NEUTRAL")
65
+
66
+ total = max(1, len(rows))
67
+ pos_pct = round(100 * pos / total, 2)
68
+ neg_pct = round(100 * neg / total, 2)
69
+ neu_pct = round(100 * neu / total, 2)
70
+
71
+ # Rolling sentiment (simple EMA-like)
72
+ rolling = []
73
+ score_map = {"POSITIVE": 1.0, "NEUTRAL": 0.5, "NEGATIVE": 0.0}
74
+ alpha = 0.2
75
+ ema = 0.5
76
+ for r in rows:
77
+ ema = alpha * score_map[r["sentiment"]] + (1 - alpha) * ema
78
+ rolling.append(round(ema, 3))
79
+
80
+ return {
81
+ "counts": {"positive": pos, "negative": neg, "neutral": neu, "total": total},
82
+ "percent": {"positive": pos_pct, "negative": neg_pct, "neutral": neu_pct},
83
+ "rolling": rolling,
84
  }
 
 
 
 
85
 
86
+ # -----------------------
87
+ # Synthetic fallback posts (no external calls)
88
+ # -----------------------
89
+ FALLBACK_PATTERNS_POS = [
90
+ "Absolutely loving {tag} right now! 🔥",
91
+ "{tag} campaign is the best thing this season 🎉",
92
+ "I love {tag}! It's amazing ❤️",
93
+ "People are talking about {tag} everywhere 🌍",
94
+ "Super excited about {tag} 🙌",
95
+ ]
96
+ FALLBACK_PATTERNS_NEG = [
97
+ "{tag} totally failed expectations 😠",
98
+ "I'm disappointed with {tag} 💔",
99
+ "{tag} needs serious improvements…",
100
+ "Not impressed by {tag} this time 😕",
101
+ ]
102
+ FALLBACK_PATTERNS_NEU = [
103
+ "People are discussing {tag} a lot 🤔",
104
+ "Not sure how I feel about {tag} yet…",
105
+ "{tag} is trending — thoughts?",
106
+ "Mixed opinions around {tag}.",
107
+ ]
108
 
109
+ def make_fallback_posts(hashtag: str, n: int) -> List[str]:
110
+ tag = hashtag if hashtag.startswith("#") else f"#{hashtag}"
111
+ posts = []
112
+ for _ in range(n):
113
+ bucket = random.choices(
114
+ [FALLBACK_PATTERNS_POS, FALLBACK_PATTERNS_NEU, FALLBACK_PATTERNS_NEG],
115
+ weights=[0.4, 0.35, 0.25],
116
+ k=1
117
+ )[0]
118
+ txt = random.choice(bucket).format(tag=tag)
119
+ posts.append(txt)
120
+ return posts
121
 
122
+ # -----------------------
123
+ # Gemini generation
124
+ # -----------------------
125
+ def generate_with_gemini(hashtag: str, n: int) -> List[str]:
126
+ """
127
+ Generate up to n short social posts using Gemini 2.0 Flash.
128
+ Returns list of strings. If API missing or error occurs, raises Exception.
129
+ """
130
+ if not GOOGLE_API_KEY:
131
+ raise RuntimeError("GOOGLE_API_KEY not set")
132
 
133
+ model = genai.GenerativeModel("gemini-2.0-flash")
134
+ tag = hashtag if hashtag.startswith("#") else f"#{hashtag}"
135
 
136
+ prompt = f"""
137
+ You are generating short, natural social posts (Twitter/Instagram style) about the topic {tag}.
138
+ Rules:
139
+ - Return exactly {n} posts.
140
+ - One post per line.
141
+ - Each post under 120 characters.
142
+ - Use a mix of positive, neutral, and critical tones.
143
+ - Avoid any hate speech, harassment, or slurs.
144
+ - Do NOT include numbering like "1." or "-".
145
+ - Do NOT wrap in code blocks.
146
+ - Language: English.
147
 
148
+ Output format:
149
+ <post 1>
150
+ <post 2>
151
+ ...
152
+ <post {n}>
153
+ """
154
+
155
+ # Simple retry to avoid transient errors
156
+ tries = 2
157
+ for i in range(tries):
158
+ try:
159
+ r = model.generate_content(prompt)
160
+ text = (r.text or "").strip()
161
+ if not text:
162
+ raise RuntimeError("Empty response from Gemini")
163
+
164
+ lines = [ln.strip() for ln in text.split("\n") if ln.strip()]
165
+ # Keep only the first n lines; also handle if Gemini returns more or fewer lines
166
+ if len(lines) < n:
167
+ # pad with fallback to hit n
168
+ lines += make_fallback_posts(hashtag, n - len(lines))
169
+ posts = lines[:n]
170
+ return posts
171
+ except Exception as e:
172
+ if i == tries - 1:
173
+ raise
174
+ time.sleep(0.8) # brief backoff
175
+
176
+ # -----------------------
177
+ # API: analyze
178
+ # Request JSON:
179
+ # { "hashtag": "gla", "count": 30 }
180
+ # -----------------------
181
+ @app.route("/api/analyze", methods=["POST"])
182
+ def analyze():
183
+ data = request.get_json(silent=True) or {}
184
+ hashtag = (data.get("hashtag") or "").strip()
185
+ count = normalize_count(data.get("count") or DEFAULT_POSTS)
186
+
187
+ if not hashtag:
188
+ return jsonify({"error": "hashtag is required"}), 400
189
+
190
+ posts: List[Dict] = []
191
+ gemini_count = 0
192
+ fallback_count = 0
193
+
194
+ # Try Gemini first; if it fails, fall back fully.
195
+ try:
196
+ gemini_posts = generate_with_gemini(hashtag, count)
197
+ for p in gemini_posts:
198
+ posts.append({"text": p, "source": "gemini"})
199
+ gemini_count = len(gemini_posts)
200
+ except Exception:
201
+ fb = make_fallback_posts(hashtag, count)
202
+ for p in fb:
203
+ posts.append({"text": p, "source": "fallback"})
204
+ fallback_count = len(fb)
205
+
206
+ # Sentiment analysis
207
+ rows = []
208
+ for p in posts:
209
+ res = sentiment_analyzer(p["text"])[0] # {'label': 'POSITIVE', 'score': 0.99}
210
+ parsed = parse_sentiment(res["label"], res["score"])
211
+ rows.append({
212
+ "text": p["text"],
213
+ "source": p["source"],
214
+ "sentiment": parsed["sentiment"],
215
+ "score": parsed["score"],
216
+ })
217
+
218
+ agg = compute_aggregate(rows)
219
+
220
+ return jsonify({
221
+ "meta": {
222
+ "hashtag": hashtag if hashtag.startswith("#") else f"#{hashtag}",
223
+ "requested": count,
224
+ "generated_by": {
225
+ "gemini": gemini_count,
226
+ "fallback": fallback_count
227
+ },
228
+ "model": {
229
+ "generation": "gemini-2.0-flash" if gemini_count > 0 else "fallback-templates",
230
+ "sentiment": SENTIMENT_MODEL
231
+ }
232
+ },
233
+ "rows": rows,
234
+ "aggregate": agg
235
+ }), 200
236
+
237
+ # -----------------------
238
+ # UI Route
239
+ # -----------------------
240
+ @app.route("/", methods=["GET"])
241
+ def home():
242
+ return render_template("index.html")
243
+
244
+ # -----------------------
245
+ # Entrypoint
246
+ # -----------------------
247
+ if __name__ == "__main__":
248
+ port = int(os.getenv("PORT", "7860"))
249
+ app.run(host="0.0.0.0", port=port, debug=False)