themodmaker committed on
Commit
560a02a
·
verified ·
1 Parent(s): 58da5b4

Source code

Browse files
Files changed (1) hide show
  1. pestle.py +437 -0
pestle.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PESTLE MODEL - COMPLETE USAGE GUIDE WITH MODEL PERSISTENCE
3
+ ==========================================================
4
+
5
+ This script demonstrates:
6
+ 1. Training and saving the model
7
+ 2. Loading a saved model
8
+ 3. Making predictions with prompts
9
+ 4. Batch predictions
10
+ """
11
+
12
+ import pandas as pd
13
+ import numpy as np
14
+ import pickle
15
+ import json
16
+ from pathlib import Path
17
+ from datetime import datetime
18
+ from scipy.sparse import hstack, csr_matrix
19
+ import warnings
20
+ warnings.filterwarnings('ignore')
21
+
22
+ from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
23
+ from sklearn.model_selection import train_test_split
24
+ from sklearn.preprocessing import LabelEncoder
25
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
26
+ from sklearn.linear_model import LogisticRegression
27
+ from sklearn.naive_bayes import MultinomialNB
28
+ from sklearn.metrics import accuracy_score, classification_report
29
+
30
+
31
class PESTLEModel:
    """PESTLE text classifier with train / save / load / predict support.

    Features are a TF-IDF bag of n-grams concatenated with one keyword-hit
    score per PESTLE category.  ``train`` fits several scikit-learn
    classifiers and keeps the one with the best held-out accuracy.

    SECURITY NOTE: ``load`` unpickles files from ``pestle_models/<name>``.
    Unpickling executes arbitrary code, so only load model directories that
    come from a trusted source.
    """

    def __init__(self):
        self.model = None
        self.vectorizers = {}
        self.label_encoder = LabelEncoder()
        self.best_model_name = None
        self.pestle_keywords = {
            'Political': ['government', 'election', 'policy', 'congress', 'senate',
                          'president', 'legislation', 'vote', 'parliament', 'diplomacy'],
            'Economic': ['economy', 'market', 'stock', 'trade', 'gdp', 'inflation',
                         'interest rate', 'unemployment', 'fed', 'revenue', 'profit'],
            'Social': ['healthcare', 'education', 'social', 'community', 'demographic',
                       'population', 'immigration', 'diversity', 'equality', 'housing'],
            'Technological': ['technology', 'ai', 'artificial intelligence', 'innovation',
                              'digital', 'cyber', 'data', 'software', 'internet', 'automation'],
            'Legal': ['law', 'court', 'legal', 'lawsuit', 'judge', 'attorney',
                      'regulation', 'compliance', 'contract', 'patent', 'trial'],
            'Environmental': ['climate', 'environment', 'carbon', 'emission', 'pollution',
                              'renewable', 'energy', 'sustainability', 'green', 'conservation']
        }
        self.metadata = {}

    @staticmethod
    def _clean_text(text):
        """Lower-case *text* and drop every character that is not a word
        character or whitespace.

        This is the single normalisation used by both ``train`` and
        ``predict`` so that inference sees text in exactly the form the
        vectorizer was fitted on.  The underscore is kept because it is a
        word character under the ``[^\\w\\s]`` rule the training data was
        originally cleaned with.
        """
        lowered = text.lower()
        return ''.join(
            ch for ch in lowered
            if ch.isalnum() or ch.isspace() or ch == '_'
        )

    def _keyword_scores(self, text):
        """Return one score per category in ``self.pestle_keywords`` (in
        dict order): the fraction of that category's keywords found in *text*.

        NOTE: matching is by plain substring, so short keywords such as
        ``'ai'`` also hit inside longer words (e.g. "said").  This is kept
        as-is so the features stay compatible with already-saved models.
        """
        return [
            (sum(1 for kw in keywords if kw in text) / len(keywords)) if keywords else 0.0
            for keywords in self.pestle_keywords.values()
        ]

    def train(self, csv_path='pestle_news_samples_6000_rows.csv'):
        """Train the model from scratch.

        Reads *csv_path*, which must provide the columns ``Headline``,
        ``Description``, ``Topic_Tags`` and ``PESTLE_Category``; fits three
        candidate classifiers on an 80/20 stratified split and keeps the one
        with the highest test accuracy in ``self.model``.

        Returns:
            True once training has finished.
        """
        print("="*80)
        print("TRAINING PESTLE MODEL".center(80))
        print("="*80)

        # Load data
        print("\n1. Loading data...")
        df = pd.read_csv(csv_path)
        print(f" ✅ Loaded {len(df)} records")

        # Prepare text features
        print("\n2. Preparing features...")
        df['text_features'] = (
            df['Headline'].fillna('') + ' ' +
            df['Description'].fillna('') + ' ' +
            df['Topic_Tags'].fillna('').str.replace(',', ' ')
        ).map(self._clean_text)

        # Keyword features: one row per document, one column per category
        keyword_features = csr_matrix(
            [self._keyword_scores(text) for text in df['text_features']]
        )

        # Encode labels
        y = self.label_encoder.fit_transform(df['PESTLE_Category'])

        # Split BEFORE fitting the vectorizer so the held-out rows do not
        # influence the vocabulary / IDF weights (avoids test-set leakage
        # into the reported accuracy).
        (text_train, text_test,
         kw_train, kw_test,
         y_train, y_test) = train_test_split(
            df['text_features'].to_numpy(), keyword_features, y,
            test_size=0.2, random_state=42, stratify=y
        )

        # TF-IDF vectorization (fitted on the training split only)
        tfidf = TfidfVectorizer(
            max_features=3000,
            ngram_range=(1, 3),
            stop_words='english',
            min_df=2,
            max_df=0.95
        )
        X_tfidf = tfidf.fit_transform(text_train)
        self.vectorizers['tfidf'] = tfidf
        print(f" ✅ TF-IDF features: {X_tfidf.shape}")

        # Combine features
        X_train = hstack([X_tfidf, kw_train], format='csr')
        X_test = hstack([tfidf.transform(text_test), kw_test], format='csr')

        print("\n3. Training models...")

        # Train multiple models
        models = {
            'Random Forest': RandomForestClassifier(
                n_estimators=200, max_depth=30, random_state=42, n_jobs=-1
            ),
            'Gradient Boosting': GradientBoostingClassifier(
                n_estimators=150, learning_rate=0.1, random_state=42
            ),
            'Logistic Regression': LogisticRegression(
                max_iter=1000, C=1.0, class_weight='balanced', random_state=42
            )
        }

        # Start below any possible accuracy so a model is always selected.
        best_score = float('-inf')
        best_model = None
        best_name = None
        best_pred = None

        for name, model in models.items():
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            print(f" {name}: {accuracy:.4f}")

            if accuracy > best_score:
                best_score = accuracy
                best_model = model
                best_name = name
                best_pred = y_pred

        self.model = best_model
        self.best_model_name = best_name

        # Store metadata (plain Python types so it can be written as JSON)
        self.metadata = {
            'model_type': best_name,
            'accuracy': float(best_score),
            'trained_date': datetime.now().isoformat(),
            'n_samples': len(df),
            'categories': self.label_encoder.classes_.tolist()
        }

        print(f"\n🏆 Best Model: {best_name} (Accuracy: {best_score:.4f})")
        print("\n Category Performance:")
        report = classification_report(y_test, best_pred,
                                       target_names=self.label_encoder.classes_,
                                       output_dict=True)
        for cat in self.label_encoder.classes_:
            f1 = report[cat]['f1-score']
            print(f" - {cat}: F1={f1:.3f}")

        return True

    def save(self, model_name="pestle_model"):
        """Save the model, vectorizers, label encoder, keywords and metadata
        under ``pestle_models/<model_name>``.

        Returns:
            The model directory as a string.

        Raises:
            ValueError: if there is no trained or loaded model to save.
        """
        # Refuse to write a directory that load() would accept but that
        # contains no usable model.
        if self.model is None:
            raise ValueError("No model to save. Call train() or load() first.")

        print(f"\n{'='*80}")
        print(f"SAVING MODEL: {model_name}".center(80))
        print("="*80)

        model_dir = Path("pestle_models") / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        # Save model
        with open(model_dir / "model.pkl", 'wb') as f:
            pickle.dump(self.model, f)
        print(f"✅ Model saved")

        # Save vectorizers
        with open(model_dir / "vectorizers.pkl", 'wb') as f:
            pickle.dump(self.vectorizers, f)
        print(f"✅ Vectorizers saved")

        # Save label encoder
        with open(model_dir / "label_encoder.pkl", 'wb') as f:
            pickle.dump(self.label_encoder, f)
        print(f"✅ Label encoder saved")

        # Save keywords
        with open(model_dir / "keywords.pkl", 'wb') as f:
            pickle.dump(self.pestle_keywords, f)
        print(f"✅ Keywords saved")

        # Save metadata
        with open(model_dir / "metadata.json", 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, indent=2)
        print(f"✅ Metadata saved")

        print(f"\n📁 Model saved to: {model_dir.absolute()}")
        return str(model_dir)

    def load(self, model_name="pestle_model"):
        """Load a model previously written by ``save``.

        All files are read first and the instance is only updated once every
        one of them has loaded, so a failure leaves the current state intact.

        Returns:
            True once everything has been loaded.

        Raises:
            FileNotFoundError: if the model directory (or one of its files)
                does not exist.
        """
        print(f"\n{'='*80}")
        print(f"LOADING MODEL: {model_name}".center(80))
        print("="*80)

        model_dir = Path("pestle_models") / model_name

        if not model_dir.exists():
            raise FileNotFoundError(f"Model directory not found: {model_dir}")

        # SECURITY: pickle.load runs arbitrary code from the file; only use
        # model directories you trust.
        with open(model_dir / "model.pkl", 'rb') as f:
            model = pickle.load(f)
        print("✅ Model loaded")

        with open(model_dir / "vectorizers.pkl", 'rb') as f:
            vectorizers = pickle.load(f)
        print("✅ Vectorizers loaded")

        with open(model_dir / "label_encoder.pkl", 'rb') as f:
            label_encoder = pickle.load(f)
        print("✅ Label encoder loaded")

        with open(model_dir / "keywords.pkl", 'rb') as f:
            pestle_keywords = pickle.load(f)
        print("✅ Keywords loaded")

        with open(model_dir / "metadata.json", 'r', encoding='utf-8') as f:
            metadata = json.load(f)
        print("✅ Metadata loaded")

        # Commit the new state in one go.
        self.model = model
        self.vectorizers = vectorizers
        self.label_encoder = label_encoder
        self.pestle_keywords = pestle_keywords
        self.metadata = metadata
        self.best_model_name = metadata.get('model_type')

        print(f"\n📊 Model Info:")
        print(f" Type: {self.metadata.get('model_type', 'Unknown')}")
        print(f" Accuracy: {self.metadata.get('accuracy', 0):.4f}")
        print(f" Trained: {self.metadata.get('trained_date', 'Unknown')}")
        print(f" Categories: {', '.join(self.metadata.get('categories', []))}")

        return True

    def predict(self, text, show_probabilities=True):
        """Predict the PESTLE category for *text*.

        Returns:
            A dict with key ``'category'``; when *show_probabilities* is true
            and the model supports ``predict_proba`` it also contains
            ``'probabilities'`` (category -> probability) and
            ``'confidence'`` (the highest probability).

        Raises:
            ValueError: if no model has been trained or loaded.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call train() or load() first.")

        # Same normalisation as used when the vectorizer was fitted
        text_processed = self._clean_text(text)

        # Extract TF-IDF features
        X_tfidf = self.vectorizers['tfidf'].transform([text_processed])

        # Extract keyword features
        keyword_features = self._keyword_scores(text_processed)

        # Combine features
        X_combined = hstack([X_tfidf, csr_matrix([keyword_features])], format='csr')

        # Predict
        prediction = self.model.predict(X_combined)[0]
        predicted_category = self.label_encoder.inverse_transform([prediction])[0]

        result = {'category': predicted_category}

        if show_probabilities and hasattr(self.model, 'predict_proba'):
            probabilities = self.model.predict_proba(X_combined)[0]
            result['probabilities'] = {
                cat: float(prob)
                for cat, prob in zip(self.label_encoder.classes_, probabilities)
            }
            result['confidence'] = float(max(probabilities))

        return result

    def predict_batch(self, texts):
        """Predict categories for multiple texts; returns one ``predict``
        result dict per input, in order."""
        return [self.predict(text, show_probabilities=True) for text in texts]
280
+
281
+
282
+ # =============================================================================
283
+ # USAGE EXAMPLES
284
+ # =============================================================================
285
+
286
def example_1_train_and_save():
    """Example 1: fit a fresh model on the sample CSV and write it to disk."""
    divider = "=" * 80
    print("\n" + divider)
    print("EXAMPLE 1: TRAIN AND SAVE MODEL".center(80))
    print(divider)

    # Train from the CSV named below, then persist under "pestle_model".
    classifier = PESTLEModel()
    classifier.train('pestle_news_samples_6000_rows.csv')
    classifier.save("pestle_model")

    print("\n✅ Model trained and saved successfully!")
297
+
298
+
299
def example_2_load_and_predict():
    """Example 2: restore the saved model and classify a few sample headlines."""
    divider = "=" * 80
    print("\n" + divider)
    print("EXAMPLE 2: LOAD MODEL AND PREDICT".center(80))
    print(divider)

    # Restore the model saved as "pestle_model"
    classifier = PESTLEModel()
    classifier.load("pestle_model")

    # Headlines to classify
    sample_headlines = (
        "Congress passes new healthcare reform bill",
        "Stock market reaches all-time high amid economic growth",
        "New AI technology revolutionizes manufacturing",
        "Supreme Court ruling on environmental regulations",
        "Rising sea levels threaten coastal communities",
        "Social media platforms face data privacy concerns",
    )

    print("\n" + divider)
    print("PREDICTIONS".center(80))
    print(divider)

    for number, headline in enumerate(sample_headlines, start=1):
        outcome = classifier.predict(headline)
        print(f"\n{number}. Text: {headline}")
        print(f" Category: {outcome['category']}")
        print(f" Confidence: {outcome['confidence']:.2%}")
        print(" Top 3 Probabilities:")
        # Rank categories from most to least likely and show the first three.
        ranked = sorted(
            outcome['probabilities'].items(),
            key=lambda pair: pair[1],
            reverse=True,
        )
        for label, share in ranked[:3]:
            print(f" - {label}: {share:.2%}")
333
+
334
+
335
def example_3_interactive_mode():
    """Example 3: Interactive prediction mode.

    Loads the saved "pestle_model" (training and saving a new one if none is
    found), then classifies each line typed by the user until they enter
    'quit', 'exit' or 'q', or until stdin is closed / Ctrl-C is pressed.
    """
    print("\n" + "="*80)
    print("EXAMPLE 3: INTERACTIVE MODE".center(80))
    print("="*80)

    model = PESTLEModel()

    # Try to load existing model, otherwise train new one
    try:
        model.load("pestle_model")
    except FileNotFoundError:
        print("\n⚠️ No saved model found. Training new model...")
        model.train('pestle_news_samples_6000_rows.csv')
        model.save("pestle_model")

    print("\n" + "="*80)
    print("Enter text to classify (or 'quit' to exit)".center(80))
    print("="*80)

    while True:
        try:
            text = input("\n📝 Enter text: ").strip()
        except (EOFError, KeyboardInterrupt):
            # stdin was closed (piped input, Ctrl-D) or the user hit Ctrl-C:
            # leave the loop cleanly instead of crashing with a traceback.
            print("\n👋 Goodbye!")
            break

        if text.lower() in ['quit', 'exit', 'q']:
            print("\n👋 Goodbye!")
            break

        if not text:
            print("⚠️ Please enter some text")
            continue

        result = model.predict(text)
        print(f"\n🎯 Predicted Category: {result['category']}")
        # 'confidence' is only present when the model exposes predict_proba.
        if 'confidence' in result:
            print(f"📊 Confidence: {result['confidence']:.2%}")
369
+
370
+
371
def example_4_batch_prediction():
    """Example 4: classify a list of headlines and export the table to CSV."""
    divider = "=" * 80
    print("\n" + divider)
    print("EXAMPLE 4: BATCH PREDICTION".center(80))
    print(divider)

    classifier = PESTLEModel()
    classifier.load("pestle_model")

    # Headlines to classify in one call
    headlines = [
        "Federal Reserve raises interest rates",
        "Climate change summit reaches agreement",
        "Tech giant faces antitrust lawsuit",
        "New immigration policy announced",
        "Breakthrough in quantum computing",
        "Healthcare costs continue to rise",
    ]

    print(f"\nProcessing {len(headlines)} texts...")
    outcomes = classifier.predict_batch(headlines)

    # One row per headline: the text, its predicted category, the confidence
    categories = []
    confidences = []
    for outcome in outcomes:
        categories.append(outcome['category'])
        confidences.append(outcome['confidence'])
    table = pd.DataFrame({
        'Text': headlines,
        'Category': categories,
        'Confidence': confidences,
    })

    print("\n" + divider)
    print("BATCH RESULTS".center(80))
    print(divider)
    print(table.to_string(index=False))

    # Write the table next to the script
    destination = "batch_predictions.csv"
    table.to_csv(destination, index=False)
    print(f"\n✅ Results saved to: {destination}")
409
+
410
+
411
+ # =============================================================================
412
+ # MAIN EXECUTION
413
+ # =============================================================================
414
+
415
if __name__ == "__main__":
    # Show the menu of runnable examples.
    divider = "=" * 80
    print("\n" + divider)
    print("PESTLE MODEL - USAGE GUIDE".center(80))
    print(divider)
    for menu_line in (
        "\nChoose an example to run:",
        "1. Train and save a new model",
        "2. Load model and make predictions",
        "3. Interactive prediction mode",
        "4. Batch prediction with export",
    ):
        print(menu_line)

    # Map each menu key to the example it launches.
    examples = {
        '1': example_1_train_and_save,
        '2': example_2_load_and_predict,
        '3': example_3_interactive_mode,
        '4': example_4_batch_prediction,
    }

    choice = input("\nEnter choice (1-4): ").strip()

    selected = examples.get(choice)
    if selected is None:
        # Anything outside 1-4 falls back to the training example.
        print("Invalid choice. Running example 1...")
        selected = example_1_train_and_save
    selected()