Dua Rajper committed on
Commit
0edb972
·
verified ·
1 Parent(s): 80cdccf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -0
app.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import google.generativeai as genai
4
+ from dotenv import load_dotenv
5
+ import json
6
+ import textwrap
7
+ import time
8
+ from typing import Any, List
9
+ import numpy as np
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ import tensorflow as tf
12
+ from tensorflow.keras.models import Sequential
13
+ from tensorflow.keras.layers import Dense, Input
14
+ from tensorflow.keras.utils import to_categorical
15
+
16
# Pull configuration from a local .env file so the API key is never
# hard-coded in the source tree.
load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Wire up the Generative AI client once at startup. Without a key the
# rest of the app cannot function, so stop rendering with a clear hint.
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    # Default text-generation model; swap for another supported model if desired.
    model = genai.GenerativeModel('gemini-pro')
else:
    st.error(
        "Google AI Studio API key not found. Please add it to your .env file. "
        "You can obtain an API key from https://makersuite.google.com/."
    )
    st.stop()
30
+
31
st.title("Embeddings and Vector Search Demo")
st.subheader("Explore Embeddings and Vector Databases")

# Static sidebar: explains the concepts the demos below exercise.
with st.sidebar:
    st.header("Embeddings and Vector Search")
    st.markdown(
        """
This app demonstrates how embeddings and vector databases can be used for various tasks.
"""
    )
    st.subheader("Key Concepts:")
    st.markdown(
        """
- **Embeddings**: Numerical representations of text, capturing semantic meaning.
- **Vector Databases**: Databases optimized for storing and querying vectors.
- **Retrieval Augmented Generation (RAG)**: Combining retrieval with LLM generation.
- **Cosine Similarity**: A measure of similarity between two vectors.
"""
    )
    st.subheader("Whitepaper Insights")
    st.markdown(
        """
- Efficient similarity search using vector indexes (e.g., ANN).
- Handling large datasets and scalability.
- Applications of embeddings: search, recommendation, classification.
"""
    )
59
+
60
# --- Helper Functions ---
def code_block(text: str, language: str = "text") -> None:
    """Display *text* as a syntax-highlighted code block in Streamlit.

    Args:
        text: The raw text to render verbatim.
        language: Syntax-highlighting hint (e.g. "python", "text").

    Fix: the previous implementation built a markdown fenced block by
    hand with ``unsafe_allow_html=True`` — it broke whenever *text*
    contained a ``` fence and needlessly allowed raw-HTML injection.
    ``st.code`` escapes content safely and highlights natively.
    """
    st.code(text, language=language)
64
+
65
+
66
def display_response(response: Any) -> None:
    """Render the model's reply, or an error banner when it is missing.

    Args:
        response: A generation result expected to expose a ``.text`` attribute.
    """
    # Guard clause: a falsy response or one without .text means generation failed.
    if not (response and hasattr(response, "text")):
        st.error("Failed to generate a response.")
        return
    st.subheader("Generated Response:")
    st.markdown(response.text)
73
+
74
+
75
def generate_embeddings(texts: List[str], model_name: str = 'models/embedding-001') -> List[List[float]]:
    """Generate embedding vectors for a list of texts via the Gemini API.

    Args:
        texts: The strings to embed.
        model_name: An embedding-capable model identifier.

    Returns:
        One embedding vector per input text, or ``[]`` on failure
        (an error banner is shown instead of raising).

    Fix: ``google.generativeai`` has no ``EmbeddingModel`` class and
    ``embed_content`` takes no ``texts=`` keyword — the original raised
    ``AttributeError`` on every call. The public entry point is
    ``genai.embed_content(model=..., content=...)``, whose result dict
    holds the vectors under the ``'embedding'`` key (a list of vectors
    when ``content`` is a list).
    """
    try:
        result = genai.embed_content(model=model_name, content=texts)
        embeddings = result["embedding"]
        # A single-string input yields one flat vector; normalize so
        # callers always receive a list of vectors.
        if embeddings and isinstance(embeddings[0], (int, float)):
            embeddings = [embeddings]
        return embeddings
    except Exception as e:
        st.error(f"Error generating embeddings: {e}")
        return []
84
+
85
+
86
+
87
def generate_with_retry(prompt: str, model_name: str, generation_config: genai.types.GenerationConfig, max_retries: int = 3, delay: int = 5) -> Any:
    """Generate content from *model_name*, retrying transient failures.

    Args:
        prompt: The text prompt to send.
        model_name: Identifier of the generative model to use.
        generation_config: Generation parameters passed through unchanged.
        max_retries: Total attempts before giving up.
        delay: Seconds to sleep between attempts.

    Returns:
        The model response, or ``None`` when the model does not exist
        (a 404 is not retryable, so an error banner is shown instead).

    Raises:
        Exception: re-raises the last API error once retries are exhausted.
    """
    # The model handle is invariant across attempts — build it once
    # instead of re-constructing it inside the retry loop.
    model = genai.GenerativeModel(model_name)
    for attempt in range(max_retries):
        try:
            return model.generate_content(prompt, generation_config=generation_config)
        except Exception as e:
            error_message = str(e)
            st.warning(f"Error during generation (attempt {attempt + 1}/{max_retries}): {error_message}")
            if "404" in error_message and "not found" in error_message:
                # An unknown model will never succeed on retry.
                st.error(
                    f"Model '{model_name}' is not available or not supported. Please select a different model."
                )
                return None
            if attempt < max_retries - 1:
                st.info(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                # Last attempt: surface the original exception to the caller.
                raise
    # Only reachable when max_retries < 1; kept as an explicit guard so
    # the function never silently returns None for that misuse.
    raise Exception("Failed to generate content after maximum retries")
108
+
109
+
110
+
111
# --- RAG Question Answering ---
st.header("RAG Question Answering")
rag_model_name = st.selectbox("Select model for RAG:", ["gemini-pro"], index=0)
rag_context = st.text_area(
    "Enter your context documents:",
    "Relevant information to answer the question. Separate documents with newlines.",
    height=150,
)
rag_question = st.text_area("Ask a question about the context:", "What is the main topic?", height=50)

if st.button("Answer with RAG"):
    if not rag_context or not rag_question:
        st.warning("Please provide both context and a question.")
    else:
        with st.spinner("Generating answer..."):
            try:
                # Split once and drop blank lines: the original split the
                # context twice and embedded empty strings, wasting API
                # calls and skewing the argmax below.
                documents = [doc for doc in rag_context.split('\n') if doc.strip()]

                # 1. Generate embeddings for the context documents.
                context_embeddings = generate_embeddings(documents)
                if not context_embeddings:
                    st.stop()

                # 2. Embed the question; guard before indexing — the
                # original raised IndexError when the embed call failed
                # and returned an empty list.
                question_embeddings = generate_embeddings([rag_question])
                if not question_embeddings:
                    st.stop()
                question_embedding = question_embeddings[0]

                # 3. Cosine similarity of the question against every document.
                similarities = cosine_similarity(
                    np.array(question_embedding).reshape(1, -1),
                    np.array(context_embeddings),
                )[0]

                # 4. Pick the single most relevant document.
                most_relevant_index = int(np.argmax(similarities))
                relevant_context = documents[most_relevant_index]

                # 5. Ground the prompt in the retrieved document.
                rag_prompt = f"Use the following context to answer the question: '{rag_question}'.\nContext: {relevant_context}"

                # 6. Generate and display the answer.
                response = generate_with_retry(rag_prompt, rag_model_name, generation_config=genai.types.GenerationConfig())
                display_response(response)
            except Exception as e:
                st.error(f"An error occurred: {e}")
150
+
151
+
152
+
153
# --- Text Similarity ---
st.header("Text Similarity")
# Must be an embedding-capable model, not a chat model.
similarity_model_name = st.selectbox("Select model for similarity:", ["models/embedding-001"], index=0)
text1 = st.text_area("Enter text 1:", "This is the first sentence.", height=50)
text2 = st.text_area("Enter text 2:", "This is a similar sentence.", height=50)

if st.button("Calculate Similarity"):
    if not (text1 and text2):
        st.warning("Please provide both texts.")
    else:
        with st.spinner("Calculating similarity..."):
            try:
                # Embed both texts in a single batch call.
                vectors = generate_embeddings([text1, text2], similarity_model_name)
                if not vectors:
                    st.stop()
                # Cosine similarity between the two embedding vectors.
                score = cosine_similarity([vectors[0]], [vectors[1]])[0][0]
                st.subheader("Cosine Similarity:")
                st.write(score)
            except Exception as e:
                st.error(f"An error occurred: {e}")
175
+
176
+
177
+
178
# --- Neural Classification ---
st.header("Neural Classification with Embeddings")
# An embedding model is required here — the classifier trains on vectors.
classification_model_name = st.selectbox("Select model for classification:", ["models/embedding-001"], index=0)
classification_data = st.text_area(
    "Enter your training data (text, label pairs), separated by newlines. Example: text1,0\\ntext2,1",
    "text1,0\ntext2,1\ntext3,0\ntext4,1",
    height=150,
)
classification_prompt = st.text_area("Enter text to classify:", "This is a test text.", height=50)
num_epochs = st.number_input("Number of Epochs", min_value=1, max_value=100, value=10, step=1)
188
+
189
+
190
def create_and_train_model(embeddings: List[List[float]], labels: List[int], num_classes: int, epochs: int):
    """Build and fit a small dense classifier on embedding vectors.

    Args:
        embeddings: One embedding vector per training example.
        labels: Integer class label per example.
        num_classes: Width of the softmax output layer.
        epochs: Number of training epochs.

    Returns:
        The trained Keras model.
    """
    embedding_dim = len(embeddings[0])
    classifier = Sequential([
        Input(shape=(embedding_dim,)),              # one input unit per embedding dimension
        Dense(16, activation='relu'),               # small hidden layer
        Dense(num_classes, activation='softmax'),   # class-probability output
    ])
    classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    # One-hot encode integer labels for categorical cross-entropy.
    one_hot_labels = to_categorical(labels, num_classes=num_classes)
    # verbose=0 keeps the per-epoch training log out of the Streamlit page.
    classifier.fit(np.array(embeddings), one_hot_labels, epochs=epochs, verbose=0)
    return classifier
202
+
203
+
204
+
205
if st.button("Classify"):
    if not classification_data or not classification_prompt:
        st.warning("Please provide training data and text to classify.")
    else:
        with st.spinner("Classifying..."):
            try:
                # 1. Parse "text,label" lines; lines without a comma are skipped.
                data_pairs = [line.split(',') for line in classification_data.split('\n') if ',' in line]
                train_texts = [pair[0].strip() for pair in data_pairs]
                train_labels = [int(pair[1].strip()) for pair in data_pairs]
                if not train_texts:
                    st.warning("Please provide training data and text to classify.")
                    st.stop()
                # to_categorical indexes by label value, so the output width
                # must cover the largest label. The original used
                # len(set(labels)), which crashes on non-contiguous label
                # sets such as {0, 2}.
                num_classes = max(train_labels) + 1

                # 2. Generate embeddings for the training texts.
                train_embeddings = generate_embeddings(train_texts, classification_model_name)
                if not train_embeddings:
                    st.stop()

                # 3. Create and train the classifier on the embeddings.
                model = create_and_train_model(train_embeddings, train_labels, num_classes, num_epochs)

                # 4. Embed the text to classify; guard before indexing —
                # the original raised IndexError when the embed call
                # failed and returned an empty list.
                predict_embeddings = generate_embeddings([classification_prompt], classification_model_name)
                if not predict_embeddings:
                    st.stop()
                predict_embedding = predict_embeddings[0]

                # 5. Predict and report the most likely class.
                prediction = model.predict(np.array([predict_embedding]), verbose=0)
                predicted_class = int(np.argmax(prediction[0]))
                st.subheader("Predicted Class:")
                st.write(predicted_class)
                st.subheader("Prediction Probabilities:")
                st.write(prediction)

            except Exception as e:
                st.error(f"An error occurred: {e}")
238
+