Dua Rajper committed on
Commit
24628cd
·
verified ·
1 Parent(s): bcb5008

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +233 -7
app.py CHANGED
@@ -21,7 +21,7 @@ if GOOGLE_API_KEY:
21
  genai.configure(api_key=GOOGLE_API_KEY)
22
  else:
23
  st.error(
24
- "Google AI Studio API key not found. Please add it to your .env file. "
25
  "You can obtain an API key from https://makersuite.google.com/."
26
  )
27
  st.stop()
@@ -40,11 +40,19 @@ with st.sidebar:
40
  st.subheader("Key Concepts:")
41
  st.markdown(
42
  """
43
- - **Embeddings**: Numerical representations of text, capturing semantic meaning.
44
- - **Vector Databases**: Databases optimized for storing and querying vectors (simulated here).
45
- - **Retrieval Augmented Generation (RAG)**: Combining retrieval with LLM generation.
46
- - **Cosine Similarity**: A measure of similarity between two vectors.
47
- - **Neural Networks**: Using embeddings as input for classification.
 
 
 
 
 
 
 
 
48
  """
49
  )
50
 
@@ -61,4 +69,222 @@ def display_response(response: Any) -> None:
61
  else:
62
  st.error("Failed to generate a response.")
63
 
64
- def generate_embeddings(texts: List[str], model_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  genai.configure(api_key=GOOGLE_API_KEY)
22
  else:
23
  st.error(
24
+ "Google AI Studio API key not found. Please add it to your .env file. "
25
  "You can obtain an API key from https://makersuite.google.com/."
26
  )
27
  st.stop()
 
40
  st.subheader("Key Concepts:")
41
  st.markdown(
42
  """
43
+ - **Embeddings**: Numerical representations of text, capturing semantic meaning.
44
+ - **Vector Databases**: Databases optimized for storing and querying vectors (simulated here).
45
+ - **Retrieval Augmented Generation (RAG)**: Combining retrieval with LLM generation.
46
+ - **Cosine Similarity**: A measure of similarity between two vectors.
47
+ - **Neural Networks**: Using embeddings as input for classification.
48
+ """
49
+ )
50
+ st.subheader("Whitepaper Insights")
51
+ st.markdown(
52
+ """
53
+ - Efficient similarity search using vector indexes (e.g., ANN).
54
+ - Handling large datasets and scalability considerations.
55
+ - Applications of embeddings: search, recommendation, classification, etc.
56
  """
57
  )
58
 
 
69
  else:
70
  st.error("Failed to generate a response.")
71
 
72
def generate_embeddings(texts: List[str], model_name: str = "models/embedding-001", task_type: str = "retrieval_document") -> Optional[List[List[float]]]:
    """Generates embeddings for a list of texts using a specified model.

    Args:
        texts: List of text strings to embed.
        model_name: Name of the Google embedding model.
        task_type: Embedding task type passed to the API. Defaults to
            "retrieval_document" (corpus text); pass "retrieval_query"
            when embedding a search query. Previously this was hard-coded,
            so query embeddings were also produced as documents.

    Returns:
        A list of embedding vectors (one per input text), or None on error.
    """
    try:
        embeddings: List[List[float]] = []
        # The API embeds one content item per call, so iterate the batch.
        for text in texts:
            result = genai.embed_content(
                model=model_name,
                content=text,
                task_type=task_type,
            )
            embeddings.append(result['embedding'])
        return embeddings
    except Exception as e:
        # Surface the failure in the UI and signal it with None so callers
        # can st.stop() instead of crashing.
        st.error(f"Error generating embeddings with model '{model_name}': {e}")
        return None
94
+
95
def generate_with_retry(prompt: str, model_name: str, generation_config: genai.types.GenerationConfig, max_retries: int = 3, delay: int = 5) -> Any:
    """Generates content with retry logic and error handling.

    Retries transient failures up to max_retries times, sleeping `delay`
    seconds between attempts. A 404 "not found" error (unknown model) is
    treated as fatal and aborts immediately. Returns the API response, or
    None when generation could not complete.
    """
    for attempt in range(1, max_retries + 1):
        try:
            return genai.GenerativeModel(model_name).generate_content(
                prompt, generation_config=generation_config
            )
        except Exception as exc:
            message = str(exc)
            st.warning(f"Error during generation (attempt {attempt}/{max_retries}): {message}")
            # An unknown/unsupported model will never succeed — do not retry.
            if "404" in message and "not found" in message:
                st.error(
                    f"Model '{model_name}' is not available or not supported. Please select a different model."
                )
                return None
            if attempt == max_retries:
                st.error(f"Failed to generate content after {max_retries} attempts. Please check your prompt and model.")
                return None
            st.info(f"Retrying in {delay} seconds...")
            time.sleep(delay)
    return None
117
+
118
def calculate_similarity(embedding1: List[float], embedding2: List[float]) -> float:
    """Calculates the cosine similarity between two embeddings.

    Computed directly with NumPy (dot product over the product of norms)
    instead of routing a single vector pair through sklearn's pairwise
    matrix API, and guards the zero-vector case explicitly.

    Args:
        embedding1: First embedding vector.
        embedding2: Second embedding vector (same length as the first).

    Returns:
        Cosine similarity in [-1.0, 1.0]; 0.0 when either vector is all
        zeros (undefined cosine, matching sklearn's behavior).
    """
    v1 = np.asarray(embedding1, dtype=float)
    v2 = np.asarray(embedding2, dtype=float)
    denom = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denom == 0.0:
        return 0.0
    return float(np.dot(v1, v2) / denom)
121
+
122
def create_and_train_model(
    embeddings: List[List[float]],
    labels: List[int],
    num_classes: int,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    optimizer_str: str
) -> tf.keras.Model:
    """Creates and trains a small dense neural network for classification.

    Args:
        embeddings: Training embedding vectors (all the same length; must
            be non-empty — the input width is taken from the first vector).
        labels: Integer class label for each embedding.
            NOTE(review): assumes labels are 0..num_classes-1; to_categorical
            fails for sparse label sets like {0, 2} — confirm with callers.
        num_classes: Number of distinct classes (softmax output width).
        epochs: Number of training epochs.
        batch_size: Mini-batch size for training.
        learning_rate: Optimizer learning rate.
        optimizer_str: One of 'adam', 'sgd', 'rmsprop' (case-insensitive);
            any other value falls back to Adam.

    Returns:
        The trained Keras model.
    """
    # Bug fix: the original Input(...) call was missing its closing
    # parenthesis (`Input(shape=(len(embeddings[0]),),`), which made this
    # whole block a syntax error.
    model = Sequential([
        Input(shape=(len(embeddings[0]),)),
        Dense(64, activation='relu'),
        Dense(32, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])

    # Dispatch table instead of an if/elif chain; unknown names use Adam.
    optimizer_factories = {
        'adam': lambda: Adam(learning_rate=learning_rate),
        'sgd': lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate),
        'rmsprop': lambda: tf.keras.optimizers.RMSprop(learning_rate=learning_rate),
    }
    optimizer = optimizer_factories.get(
        optimizer_str.lower(), optimizer_factories['adam']
    )()

    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    encoded_labels = to_categorical(labels, num_classes=num_classes)
    # verbose=0 keeps Keras progress bars out of the Streamlit UI.
    model.fit(np.array(embeddings), encoded_labels, epochs=epochs, batch_size=batch_size, verbose=0)
    return model
152
+
153
# --- RAG Question Answering ---
st.header("RAG Question Answering")
rag_model_name = st.selectbox("Select model for RAG:", ["gemini-pro"], index=0)
rag_embedding_model = st.selectbox("Select embedding model for RAG:", ["models/embedding-001"], index=0)
rag_context = st.text_area(
    "Enter your context documents:",
    "Relevant information to answer the question. Separate documents with newlines.",
    height=150,
)
rag_question = st.text_area("Ask a question about the context:", "What is the main topic?", height=70)
rag_max_context_length = st.number_input("Maximum Context Length", min_value=100, max_value=2000, value=500, step=100)

if st.button("Answer with RAG"):
    if not rag_context or not rag_question:
        st.warning("Please provide both context and a question.")
    else:
        with st.spinner("Generating answer..."):
            try:
                # Split the context ONCE and drop blank lines. The original
                # split twice (embedding and retrieval) and kept blanks, so
                # an empty string could be embedded and even retrieved as
                # the "most relevant" document.
                documents = [doc for doc in rag_context.split('\n') if doc.strip()]
                if not documents:
                    st.warning("Please provide both context and a question.")
                    st.stop()

                # 1. Generate embeddings for the context documents.
                context_embeddings = generate_embeddings(documents, rag_embedding_model)
                if not context_embeddings:
                    st.stop()

                # 2. Generate embedding for the question.
                question_embedding = generate_embeddings([rag_question], rag_embedding_model)
                if not question_embedding:
                    st.stop()

                # 3. Score each document against the question.
                similarities = cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(context_embeddings))[0]

                # 4. Keep the single best document, truncated to the limit.
                relevant_context = documents[int(np.argmax(similarities))]
                if len(relevant_context) > rag_max_context_length:
                    relevant_context = relevant_context[:rag_max_context_length]

                # 5. Construct the grounded prompt.
                rag_prompt = f"Use the following context to answer the question: '{rag_question}'.\nContext: {relevant_context}"

                # 6. Generate and display the answer.
                response = generate_with_retry(rag_prompt, rag_model_name, generation_config=genai.types.GenerationConfig())
                if response:
                    display_response(response)
            except Exception as e:
                st.error(f"An error occurred: {e}")
199
+
200
# --- Text Similarity ---
st.header("Text Similarity")
similarity_embedding_model = st.selectbox("Select embedding model for similarity:", ["models/embedding-001"], index=0)
text1 = st.text_area("Enter text 1:", "This is the first sentence.", height=70)
text2 = st.text_area("Enter text 2:", "This is a similar sentence.", height=70)

if st.button("Calculate Similarity"):
    # Both texts are required before we spend an API call on embeddings.
    if text1 and text2:
        with st.spinner("Calculating similarity..."):
            try:
                pair_embeddings = generate_embeddings([text1, text2], similarity_embedding_model)
                if not pair_embeddings:
                    st.stop()
                score = calculate_similarity(pair_embeddings[0], pair_embeddings[1])
                st.subheader("Cosine Similarity:")
                st.write(score)
            except Exception as e:
                st.error(f"An error occurred: {e}")
    else:
        st.warning("Please provide both texts.")
220
+
221
# --- Neural Classification ---
# Widget declaration order defines the UI layout — do not reorder.
st.header("Neural Classification with Embeddings")
# Embedding model used for both the training texts and the text to classify.
classification_embedding_model = st.selectbox("Select embedding model for classification:", ["models/embedding-001"], index=0)
# Training data: one "text,label" pair per line; labels are integers
# (parsed later by process_classification_data).
classification_data = st.text_area(
    "Enter your training data (text, label pairs), separated by newlines. Example: text1,0\\ntext2,1",
    "text1,0\ntext2,1\ntext3,0\ntext4,1",
    height=150,
)
classification_prompt = st.text_area("Enter text to classify:", "This is a test text.", height=70)
# Training hyperparameters for the small dense classifier built in
# create_and_train_model.
num_epochs = st.number_input("Number of Epochs", min_value=1, max_value=200, value=10, step=1)
batch_size = st.number_input("Batch Size", min_value=1, max_value=128, value=32, step=1)
learning_rate = st.number_input("Learning Rate", min_value=0.0001, max_value=0.1, value=0.0001, step=0.0001, format="%.4f")
optimizer_str = st.selectbox("Optimizer", ['adam', 'sgd', 'rmsprop'], index=0)
234
+
235
def process_classification_data(data: str) -> Optional[tuple[List[str], List[int]]]:
    """Processes the classification data string into lists of texts and labels.

    Each non-empty line must look like "text,label" with an integer label.
    Reports malformed input via st.error and returns None; otherwise
    returns (texts, labels) as parallel lists.
    """
    data_pairs = [line.split(',') for line in data.split('\n') if ',' in line]
    if not data_pairs:
        st.error("No valid data pairs found. Please ensure each line contains 'text,label'.")
        return None

    texts: List[str] = []
    labels: List[int] = []
    for i, pair in enumerate(data_pairs):
        # A line with more than one comma splits into >2 fields — reject it.
        if len(pair) != 2:
            st.error(f"Invalid data format in line {i + 1}: '{','.join(pair)}'. Expected 'text,label'.")
            return None
        text, label_str = (field.strip() for field in pair)
        try:
            labels.append(int(label_str))
        except ValueError:
            st.error(f"Invalid label value in line {i + 1}: '{label_str}'. Label must be an integer.")
            return None
        texts.append(text)
    return texts, labels
257
+
258
if st.button("Classify"):
    if not classification_data or not classification_prompt:
        st.warning("Please provide training data and text to classify.")
    else:
        with st.spinner("Classifying..."):
            try:
                # Parse "text,label" lines; None means malformed input
                # (already reported to the UI by the parser).
                processed_data = process_classification_data(classification_data)
                if not processed_data:
                    st.stop()
                train_texts, train_labels = processed_data
                # NOTE(review): assumes labels are 0..num_classes-1;
                # to_categorical breaks for sparse sets like {0, 2}.
                num_classes = len(set(train_labels))

                train_embeddings = generate_embeddings(train_texts, classification_embedding_model)
                if not train_embeddings:
                    st.stop()

                model = create_and_train_model(
                    train_embeddings, train_labels, num_classes, num_epochs, batch_size, learning_rate, optimizer_str
                )

                predict_embedding = generate_embeddings([classification_prompt], classification_embedding_model)
                if not predict_embedding:
                    st.stop()

                # Bug fix: generate_embeddings already returns a batch (a
                # list containing one embedding), so the original
                # np.array([predict_embedding]) produced a (1, 1, dim)
                # array — a shape mismatch against the model's (dim,) input.
                prediction = model.predict(np.array(predict_embedding), verbose=0)
                predicted_class = np.argmax(prediction[0])
                st.subheader("Predicted Class:")
                st.write(predicted_class)
                st.subheader("Prediction Probabilities:")
                st.write(prediction)

            except Exception as e:
                st.error(f"An error occurred: {e}")