subashree committed on
Commit
824324e
·
verified ·
1 Parent(s): 300c4ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +300 -1
app.py CHANGED
@@ -1,2 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
- exec(os.environ.get('app'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import time

import certifi
import pandas as pd
import requests
import streamlit as st
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
from pinecone import Pinecone, ServerlessSpec
from tiktoken import get_encoding

# Load environment variables from a local .env file
# (expects OPENAI_API_KEY and PINECONE_API_KEY).
load_dotenv()

# Configuration
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# SECURITY: the Pinecone key must come from the environment — it was previously
# hard-coded here and committed to source control. Rotate the leaked key.
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
INDEX_NAME = "gradient-cyber"
BATCH_SIZE = 100    # vectors per Pinecone upsert request
MAX_RESULTS = 1000  # upper bound on query result size

# Initialize the OpenAI client.
client = OpenAI(api_key=OPENAI_API_KEY)

# Initialize Pinecone; create the index only if it does not exist yet.
pc = Pinecone(api_key=PINECONE_API_KEY)
if INDEX_NAME not in pc.list_indexes().names():
    pc.create_index(
        name=INDEX_NAME,
        dimension=1536,  # output dimension of text-embedding-ada-002
        metric='cosine',
        spec=ServerlessSpec(cloud='aws', region='us-east-1')
    )
index = pc.Index(INDEX_NAME)
40
+
41
+ # Define helper functions
42
def truncate_text(text, max_tokens, encoding_name="cl100k_base"):
    """Truncate *text* to at most *max_tokens* tokens.

    Defaults to the cl100k_base encoding, the tokenizer actually used by
    text-embedding-ada-002. The previous "gpt2" encoding counted tokens with a
    different vocabulary, so truncated text could still exceed the embedding
    model's 8192-token limit; this also makes the count consistent with
    num_tokens_from_string().

    Args:
        text: The input string.
        max_tokens: Maximum number of tokens to keep.
        encoding_name: tiktoken encoding to tokenize with.

    Returns:
        The decoded prefix of *text* containing at most *max_tokens* tokens.
    """
    tokenizer = get_encoding(encoding_name)
    tokens = tokenizer.encode(text)
    return tokenizer.decode(tokens[:max_tokens])
46
+
47
def generate_embedding(text):
    """Embed *text* with text-embedding-ada-002, retrying transient failures.

    Retries up to three times with exponential backoff between attempts.

    Returns:
        (embedding, total_tokens) on success, (None, 0) after all retries fail
        (an error is surfaced in the Streamlit UI in that case).
    """
    max_retries = 3
    attempt = 0
    while attempt < max_retries:
        try:
            result = client.embeddings.create(
                model="text-embedding-ada-002",
                input=text
            )
        except Exception as exc:
            if attempt == max_retries - 1:
                st.error(f"Error creating embedding after {max_retries} attempts: {str(exc)}")
                return None, 0
            time.sleep(2 ** attempt)  # Exponential backoff
            attempt += 1
        else:
            return result.data[0].embedding, result.usage.total_tokens
    return None, 0
61
+
62
def upsert_in_batches(index, vectors, batch_size=100):
    """Upsert *vectors* into *index* in chunks of *batch_size*.

    All writes go to the "ns1" namespace; a failed chunk is reported in the
    Streamlit UI and the remaining chunks are still attempted.
    """
    for start in range(0, len(vectors), batch_size):
        chunk = vectors[start:start + batch_size]
        try:
            index.upsert(vectors=chunk, namespace="ns1")
        except Exception as err:
            st.error(f"Error upserting batch: {err}")
69
+
70
def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    """Return the number of tokens in *string* under the given tiktoken encoding."""
    return len(tiktoken.get_encoding(encoding_name).encode(string))
74
+
75
def semantic_similarity(text1, text2):
    """Dot product of the two texts' embeddings.

    OpenAI embeddings are unit-normalized, so this equals cosine similarity.
    Returns 0 if either embedding could not be generated.
    """
    vec1, _ = generate_embedding(text1)
    vec2, _ = generate_embedding(text2)
    if vec1 is None or vec2 is None:
        return 0
    score = 0
    for a, b in zip(vec1, vec2):
        score += a * b
    return score
81
+
82
def expand_query(original_query):
    """Append 3-5 LLM-generated related questions/terms to *original_query*.

    Falls back to the unmodified query (after reporting the error in the UI)
    if the chat-completion call fails.
    """
    expansion_prompt = f"Expand the following query into 3-5 related questions or terms: '{original_query}'"
    try:
        completion = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": expansion_prompt}],
            max_tokens=100,
            temperature=0.7
        )
    except Exception as exc:
        st.error(f"Error in query expansion: {str(exc)}")
        return original_query  # Return original query if expansion fails
    return original_query + " " + completion.choices[0].message.content
96
+
97
def truncate_context(context, max_tokens=14000):
    """Clip *context* to its first *max_tokens* cl100k_base tokens."""
    enc = tiktoken.get_encoding("cl100k_base")
    return enc.decode(enc.encode(context)[:max_tokens])
102
+
103
# ---------------------------------------------------------------------------
# Streamlit UI: Excel upload and ingestion into Pinecone
# ---------------------------------------------------------------------------
st.title("Gradient-cyber")
uploaded_file = st.file_uploader("Upload an Excel file", type=["xlsx"])

# Every spreadsheet column whose (truncated) value is copied into the Pinecone
# metadata for a row. Replaces ~80 hand-written dict entries.
METADATA_FIELDS = (
    "ID", "eventDtgTime", "alerts", "displayTitle", "instantAnalytics",
    "detailedText", "msgPrecs", "unit", "size", "embedHtml", "dataSources",
    "snippetText", "contentLink", "description", "imageDescription",
    "reportSummary", "authorName", "timeReportCompleted", "attachment",
    "latitude", "securityLevels", "imagereSourceLink", "eventDtg", "status",
    "users", "name", "sessions", "fiscalStatus", "sentimentSummary",
    "sourceOrg", "dateCreated", "active", "responseSummary",
    "comparisonCommunitiesCountries", "activity", "applications", "url",
    "timeZones", "location", "longitude", "dateModified", "pedigrees",
    "gistComment", "tag", "geoCode", "time", "timeReportRouted", "rteToOrg",
    "copyReportToOrg", "sourceOrganization", "coordinates", "image1",
    "image2", "image3", "image4", "image5", "numEmailsSent", "lastEmailDate",
    "reportDtg", "metadata", "eventOrganizations", "classification",
    "assetIPs", "sitrepTemplate", "industry", "networkSegmentList",
    "approvedDate", "incident", "sendEmail", "newFormat", "duMapping",
    "jsonTag", "createdFrom", "integrationData", "mtti", "mttd", "mttr",
    "oldEventDate", "org_event_name",
)


def truncate_field(field, max_length=500):
    """Stringify *field*, clipped to *max_length* chars; NaN becomes ''.

    Hoisted to module level — it was previously re-defined on every row of the
    ingestion loop.
    """
    return str(field)[:max_length] if not pd.isna(field) else ''


if uploaded_file is not None:
    st.write("File uploaded successfully!")
    # Load Excel file
    df = pd.read_excel(uploaded_file)
    st.write("Excel file loaded:")
    st.write(df.head())

    # Concatenate text from all columns for each row into a readable sentence.
    def create_meaningful_sentence(row):
        return '. '.join([f"{col.replace('_', ' ')}: {row[col]}" for col in df.columns])

    df['combined_text'] = df.apply(create_meaningful_sentence, axis=1)
    st.write("Columns concatenated into meaningful sentences:")
    st.write(df[['combined_text']].head())

    vectors = []
    total_tokens_used = 0
    total_requests = 0
    for i, row in df.iterrows():
        # Truncate to the embedding model's maximum context length.
        text = truncate_text(row['combined_text'], max_tokens=8192)
        embedding, tokens_used = generate_embedding(text)
        if embedding is not None:
            total_tokens_used += tokens_used
            total_requests += 1

            # Build metadata from the column list; a missing column yields ''
            # (the original indexed a few columns directly and raised KeyError
            # when absent — .get() is safer and otherwise equivalent).
            metadata = {field: truncate_field(row.get(field, '')) for field in METADATA_FIELDS}
            metadata["combined_text"] = text  # Add combined text to metadata

            vectors.append({'id': str(row['ID']), 'values': embedding, 'metadata': metadata})

    if vectors:
        upsert_in_batches(index, vectors, BATCH_SIZE)
        st.success("Data successfully uploaded to Pinecone.")  # was a placeholder-free f-string
        st.info(f"Total tokens used: {total_tokens_used}")
        st.info(f"Total requests made: {total_requests}")
    else:
        st.warning("No embeddings were generated.")
232
+
233
# ---------------------------------------------------------------------------
# Streamlit UI: query input, retrieval, re-ranking, and answer generation
# ---------------------------------------------------------------------------
query = st.text_input("Enter your query:")

if query:
    try:
        expanded_query = expand_query(query)
        query_embedding, _ = generate_embedding(expanded_query)
        if query_embedding is not None:
            # Vector search in Pinecone.
            results = index.query(
                namespace="ns1",
                vector=query_embedding,
                top_k=50,
                include_metadata=True
            )

            # Semantic re-ranking against the ORIGINAL (unexpanded) query.
            # Embed it once up front: the previous code called
            # semantic_similarity() in the sort key, re-embedding the query
            # for every match (two API calls per match instead of one).
            base_embedding, _ = generate_embedding(query)

            def rerank_score(match):
                """Dot product of the query embedding and this match's text embedding."""
                if base_embedding is None:
                    return 0
                doc_embedding, _ = generate_embedding(match['metadata']['combined_text'])
                if doc_embedding is None:
                    return 0
                return sum(a * b for a, b in zip(base_embedding, doc_embedding))

            # Keep the 10 best matches to bound the prompt size.
            # (The old comment claimed 15; the slice has always been 10.)
            filtered_results = sorted(results['matches'], key=rerank_score, reverse=True)[:10]

            # Prepare context for GPT.
            context = "\n".join(
                f"ID: {match['id']}\n"
                f"Event Date/Time: {match['metadata'].get('eventDtgTime', 'N/A')}\n"
                f"Display Title: {match['metadata'].get('displayTitle', 'N/A')}\n"
                f"Status: {match['metadata'].get('status', 'N/A')}\n"
                f"Combined Text: {match['metadata'].get('combined_text', 'N/A')}\n"
                "---"
                for match in filtered_results
            )

            # Truncate the context to fit the completion model's window.
            truncated_context = truncate_context(context)

            # System prompt rebuilt as a clean concatenation — the old
            # triple-quoted string contained stray orphan '"' characters left
            # over from a botched conversion of concatenated string literals.
            system_prompt = (
                "Core Capabilities: Expert Knowledge on SITREPs: Understand and explain SITREP components "
                "like threat analysis, incident summaries, and risk assessments. Flexible Query Handling: "
                "Interpret and respond to diverse queries, supporting access to data by date, theme, severity, etc. "
                "Data Retrieval: Provide specific information such as incident dates, threat actors, and mitigation strategies. "
                "Offer summaries or detailed reports based on user needs. Analytical Engagement: Engage in discussions, offering insights and hypotheses. "
                "Support user analysis of cyber threats and incidents. Interaction Guidelines: Understanding Queries: "
                "Use NLP to interpret user queries and ask clarifying questions if needed. Providing Context: Give context to help users understand the relevance and implications of information. "
                "Customizing Detail: Adjust detail levels based on user preferences, providing summaries or deep dives. "
                "Allow users to specify the format of information (e.g., bullet points, detailed paragraphs, tables). "
                "Encouraging Exploration: Suggest related queries and additional information. Provide thorough explanations to support learning and decision-making."
            )
            user_prompt = (
                f"Query: {query}\n"
                f"Relevant Information:\n"
                f"{truncated_context}\n"
                "Provide a clear, concise, and comprehensive answer. Synthesize information from multiple entries if necessary. Cite specific details and examples when applicable. If information is missing, state what is known and what remains uncertain."
            )

            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                max_tokens=1000,
                temperature=0.7
            )
            answer = response.choices[0].message.content
            st.write("Answer to your query:")
            st.write(answer)
        else:
            st.error("Failed to generate query embedding.")
    except Exception as e:
        st.error(f"Error processing query: {str(e)}")