File size: 7,734 Bytes
dcc24f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd5cfe5-ea7e-49d8-9ef3-5d43bef5a0cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "import re\n",
    "import random\n",
    "import time\n",
    "from pathlib import Path\n",
    "from tqdm.notebook import tqdm\n",
    "from mlx_lm import load, generate\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "# Cell 11: Load parsed emails from cache\n",
    "\n",
    "cache_path = PROJECT / \"data/parsed/emails.json\"\n",
    "\n",
    "with open(cache_path, 'r', encoding='utf-8') as f:\n",
    "    parsed_emails = json.load(f)\n",
    "\n",
    "print(f\"✅ Loaded {len(parsed_emails):,} emails from cache\")\n",
    "\n",
    "# Cell 12: Random sampling\n",
    "random.seed(42)\n",
    "\n",
    "# Pick 500 random emails\n",
    "sample_size = 500\n",
    "sample_emails = random.sample(parsed_emails,sample_size)\n",
    "\n",
    "print(f\"Total emails: {len(parsed_emails):,}\")\n",
    "print(f\"Sample size: {len(sample_emails)}\")\n",
    "\n",
    "# Preview one sample\n",
    "print(f\"\\n=== SAMPLE EMAIL #1 ===\")\n",
    "print(f\"Subject: {sample_emails[0]['subject']}\")\n",
    "print(f\"Sender: {sample_emails[0]['sender']}\")\n",
    "print(f\"Body: {sample_emails[0]['body'][:300]}...\")\n",
    "\n",
    "# Cell 13: Classification prompt template\n",
    "\n",
    "CLASSIFICATION_PROMPT = \"\"\"You are an email classifier. Analyze this email and categorize it.\n",
    "\n",
    "EMAIL:\n",
    "Subject: {subject}\n",
    "From: {sender}\n",
    "Body: {body}\n",
    "\n",
    "TASK:\n",
    "Classify this email into exactly ONE category.\n",
    "\n",
    "CATEGORIES:\n",
    "- finance: Banks, payments, transactions, investments, credit cards, loans, UPI, wallets\n",
    "- shopping: Orders, deliveries, purchases, e-commerce\n",
    "- social: Social networks, personal messages, invitations\n",
    "- work: Job-related, recruitment, office, meetings, projects\n",
    "- newsletter: Digests, subscriptions, blogs, articles\n",
    "- promotional: Marketing, offers, discounts, advertisements\n",
    "- other: Anything that doesn't fit above\n",
    "\n",
    "OUTPUT FORMAT (JSON only, no other text):\n",
    "{{\"category\": \"<category>\", \"confidence\": \"<high/medium/low>\", \"reason\": \"<brief 5-10 word reason>\"}}\n",
    "\"\"\"\n",
    "\n",
    "def build_prompt(email_data):\n",
    "    \"\"\"Build classification prompt for one email.\"\"\"\n",
    "    return CLASSIFICATION_PROMPT.format(\n",
    "        subject=email_data['subject'][:200],\n",
    "        sender=email_data['sender'][:100],\n",
    "        body=email_data['body'][:2000]\n",
    "    )\n",
    "\n",
    "# Test: See what prompt looks like\n",
    "test_prompt = build_prompt(sample_emails[0])\n",
    "print(f\"Prompt length: {len(test_prompt)} characters\")\n",
    "print(f\"\\n=== PROMPT PREVIEW ===\\n{test_prompt[:1000]}...\")\n",
    "\n",
    "# Cell 14: Load Phi-3 model\n",
    "model_path = str(PROJECT / \"models/base/phi3-mini\")\n",
    "\n",
    "print(\"Loading Phi-3 model...\")\n",
    "model, tokenizer = load(model_path)\n",
    "print(\"✅ Model loaded\")\n",
    "\n",
    "# Cell 15: Test classification on one email\n",
    "test_email = sample_emails[0]\n",
    "\n",
    "# Build prompt\n",
    "prompt = build_prompt(test_email)\n",
    "\n",
    "# Send to Phi-3\n",
    "print(\"Classifying email...\")\n",
    "print(f\"Subject: {test_email['subject'][:80]}...\")\n",
    "print(\"-\" * 50)\n",
    "\n",
    "response = generate(\n",
    "    model, \n",
    "    tokenizer, \n",
    "    prompt=prompt,\n",
    "    max_tokens=100,\n",
    "    verbose=False\n",
    ")\n",
    "\n",
    "print(f\"\\n=== PHI-3 RESPONSE ===\\n{response}\")\n",
    "\n",
    "# Cell 16: JSON extraction helper\n",
    "\n",
    "def extract_json(response):\n",
    "    \"\"\"Extract the first flat JSON object from an LLM response, or return None.\"\"\"\n",
    "\n",
    "    # Find the first flat {...} span (no nested braces) in the response\n",
    "    match = re.search(r'\\{[^{}]*\\}', response)\n",
    "\n",
    "    if(match):\n",
    "          try:\n",
    "              return json.loads(match.group())\n",
    "          except json.JSONDecodeError:\n",
    "              return None\n",
    "    return None\n",
    "\n",
    "# Test on previous response\n",
    "parsed = extract_json(response)\n",
    "\n",
    "print(\"=== EXTRACTED JSON ===\")\n",
    "print(parsed)\n",
    "print(f\"\\nCategory: {parsed['category']}\")\n",
    "print(f\"Confidence: {parsed['confidence']}\")\n",
    "print(f\"Reason: {parsed['reason']}\")\n",
    "\n",
    "# Cell 17: Classify all sample emails\n",
    "results = []\n",
    "failed = 0\n",
    "\n",
    "print(f\"Classifying {len(sample_emails)} emails...\")\n",
    "print(\"Estimated time: ~5 minutes\\n\")\n",
    "\n",
    "start_time = time.time()\n",
    "\n",
    "for i, email_data in enumerate(tqdm(sample_emails, desc=\"Classifying\")):\n",
    "    try:\n",
    "        # Build prompt\n",
    "        prompt = build_prompt(email_data)\n",
    "        \n",
    "        # Get classification\n",
    "        response = generate(\n",
    "            model, \n",
    "            tokenizer, \n",
    "            prompt=prompt,\n",
    "            max_tokens=100,\n",
    "            verbose=False\n",
    "        )\n",
    "        \n",
    "        # Extract JSON\n",
    "        parsed = extract_json(response)\n",
    "        \n",
    "        if parsed:\n",
    "            results.append({\n",
    "                'id': email_data.get('id', i),\n",
    "                'subject': email_data['subject'],\n",
    "                'sender': email_data['sender'],\n",
    "                'category': parsed.get('category', 'other'),\n",
    "                'confidence': parsed.get('confidence', 'low'),\n",
    "                'reason': parsed.get('reason', '')\n",
    "            })\n",
    "        else:\n",
    "            failed += 1\n",
    "            \n",
    "    except Exception as e:\n",
    "        failed += 1\n",
    "        continue\n",
    "\n",
    "elapsed = time.time() - start_time\n",
    "\n",
    "print(f\"\\n✅ Classified: {len(results)}\")\n",
    "print(f\"❌ Failed: {failed}\")\n",
    "print(f\"⏱️ Time: {elapsed/60:.1f} minutes\")\n",
    "print(f\"⚡ Speed: {len(results)/elapsed:.1f} emails/sec\")\n",
    "\n",
    "# Cell 18: Category distribution\n",
    "\n",
    "categories = Counter([r['category'] for r in results])\n",
    "\n",
    "print(\"=== CATEGORY DISTRIBUTION ===\\n\")\n",
    "for category, count in categories.most_common():\n",
    "    pct = count / len(results) * 100\n",
    "    bar = \"█\" * int(pct / 2)\n",
    "    print(f\"{category:12} {count:4} ({pct:5.1f}%) {bar}\")\n",
    "\n",
    "print(f\"\\n📊 Total classified: {len(results)}\")\n",
    "\n",
    "# Cell 19: Save classification results\n",
    "results_path = PROJECT / \"data/parsed/classification_results.json\"\n",
    "\n",
    "with open(results_path, 'w', encoding='utf-8') as f:\n",
    "    json.dump(results, f, ensure_ascii=False, indent=2)\n",
    "\n",
    "print(f\"✅ Saved {len(results)} results to {results_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}