vicpada committed on
Commit
9f7e7ff
·
0 Parent(s):

Adding MVP

Browse files
Files changed (4) hide show
  1. .gitignore +7 -0
  2. README.md +11 -0
  3. app.py +212 -0
  4. requirements.txt +0 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .env
3
+ data/*
4
+ ~data/.gitkeep
5
+ venv
6
+ scripts
7
+ .ipynb*
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Super profes
3
+ emoji: 💡
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: "4.44.1"
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard Library Imports
2
+ import logging
3
+ import os
4
+
5
+ # Third-party Imports
6
+ from dotenv import load_dotenv
7
+ import chromadb
8
+ import logfire
9
+ import gradio as gr
10
+ from huggingface_hub import snapshot_download
11
+
12
+
13
+ # LlamaIndex (Formerly GPT Index) Imports
14
+ from llama_index.core import VectorStoreIndex
15
+ from llama_index.core.retrievers import VectorIndexRetriever
16
+ from llama_index.vector_stores.chroma import ChromaVectorStore
17
+ from llama_index.core.llms import MessageRole
18
+ from llama_index.core.memory import ChatSummaryMemoryBuffer
19
+ from llama_index.core.tools import RetrieverTool, ToolMetadata
20
+ from llama_index.agent.openai import OpenAIAgent
21
+ from llama_index.llms.openai import OpenAI
22
+ from llama_index.core import Settings
23
+ from llama_index.postprocessor.cohere_rerank import CohereRerank
24
+ from llama_index.embeddings.openai import OpenAIEmbedding
25
+
26
# Read a local .env file (if any) into the process environment first, so the
# configuration calls below can see the variables it defines.
load_dotenv()

# Initialise Logfire with its default configuration.
logfire.configure()

# Root logging at INFO; httpx is raised to WARNING to quieten its output.
logging.basicConfig(level=logging.INFO)
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
33
+
34
# System prompt given to the agent. It names the "Super_profe" tool, so the
# tool name used when the tool is registered must stay in sync with this text.
PROMPT_SYSTEM_MESSAGE = (
    "\n"
    "You are an AI assistant expert responding to user queries with relevant information and context. Your expertise is to find the most relevant teacher for a student.\n"
    "You take into account what the teacher studies are, any recommendations they may have and their score.\n"
    'To find relevant information use the "Super_profe" tool. This tool returns the teachers information.\n'
    "For each response always include the teacher's name, subjects, recommendations, and score and picture.\n"
    "If the question is not related to finding a teacher, please provide more context or rephrase your question.\n"
)
41
+
42
def download_knowledge_base_if_not_exists(
    local_dir="data/superprofe",
    repo_id="vicpada/SuperProfes",
):
    """Download the knowledge base from the Hugging Face Hub unless it is already present.

    Args:
        local_dir: Directory that should hold the vector database.
        repo_id: Hugging Face dataset repository to download from.

    The Hub token is read from the ``HF_TOKEN`` environment variable and may
    be unset.
    """
    # An existing but EMPTY directory counts as missing. The directory is
    # created before the download starts, so a download that fails early
    # would otherwise leave an empty folder behind and every later run would
    # wrongly skip the download.
    if os.path.isdir(local_dir) and os.listdir(local_dir):
        return

    logging.warning(
        "Vector database does not exist at '%s', downloading from Hugging Face Hub...",
        local_dir,
    )

    # exist_ok=True so a leftover empty directory does not raise.
    os.makedirs(local_dir, exist_ok=True)

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        repo_type="dataset",
        token=os.getenv("HF_TOKEN"),
    )
    logging.info("Downloaded vector database to '%s'", local_dir)
59
+
60
def get_tools(db_collection="superprofe", cohere_api_key=None):
    """Build the list of retrieval tools handed to the agent.

    Opens the persistent Chroma database under ``data/<db_collection>``,
    wraps the collection of the same name in a vector index, and exposes a
    retriever (top 200 by similarity) followed by a Cohere reranker (top 5)
    as a single tool named ``Super_profe``.

    Args:
        db_collection: Name of both the data sub-directory and the Chroma
            collection.
        cohere_api_key: API key passed to the Cohere reranker.

    Returns:
        A one-element list containing the ``RetrieverTool``.
    """
    client = chromadb.PersistentClient(path=f"data/{db_collection}")
    collection = client.get_or_create_collection(db_collection)
    store = ChromaVectorStore(chroma_collection=collection)

    logging.info(f"Vector store initialized with {collection.count()} documents.")

    # Build the index on top of the existing vector store.
    logging.info("Creating vector store index...")
    index = VectorStoreIndex.from_vector_store(
        vector_store=store,
        show_progress=True,
        use_async=True,
        embed_model=Settings.embed_model,
    )

    logging.info("Creating vector retriever...")
    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=200,
        embed_model=Settings.embed_model,
        use_async=True,
        verbose=True,
    )

    reranker = CohereRerank(
        top_n=5,
        model="rerank-english-v3.0",
        api_key=cohere_api_key,
    )

    logging.info("Creating tool...")
    tool_metadata = ToolMetadata(
        name="Super_profe",
        description="Useful for selecting the best teacher.",
    )
    return [
        RetrieverTool(
            retriever=retriever,
            metadata=tool_metadata,
            node_postprocessors=[reranker],
        )
    ]
105
+
106
def _trim_memory_to_history(memory, history):
    """Cut the chat memory back so it holds no more user turns than ``history``."""
    chat_list = memory.get()
    if not chat_list:
        return

    user_positions = [
        position
        for position, message in enumerate(chat_list)
        if message.role == MessageRole.USER
    ]
    # More user messages in memory than entries in the UI history: drop the
    # first surplus user message and everything after it.
    if len(user_positions) > len(history):
        memory.set(chat_list[: user_positions[len(history)]])


def generate_completion(query, history, memory):
    """Stream the agent's answer to ``query`` as progressively longer strings.

    Args:
        query: Text of the user's latest message.
        history: Chat history supplied by the chat UI; its length is used to
            reconcile ``memory`` with what the UI currently shows.
        memory: Chat memory object exposing ``get()`` and ``set()``; it is
            also handed to the agent.

    Yields:
        The accumulated answer after each streamed token. If the
        ``OPENAI_API_KEY`` or ``COHERE_API_KEY`` environment variable is
        missing or malformed, a single error message is yielded instead.
    """
    logging.info("User query: %s", query)
    logging.info("User history: %s", history)
    logging.info("User memory: %s", memory)

    openai_api_key = os.getenv("OPENAI_API_KEY")
    cohere_api_key = os.getenv("COHERE_API_KEY")

    # Validate BOTH keys before building any client, so a request that is
    # going to be rejected does not construct and instrument an OpenAI client.
    if openai_api_key is None or not openai_api_key.startswith("sk-"):
        logging.error("OpenAI API Key is not set or is invalid. Please provide a valid key.")
        yield "Error: OpenAI API Key is not set or is invalid. Please provide a valid key."
        return

    if cohere_api_key is None or not cohere_api_key.strip():
        logging.error("Cohere API Key is not set or is invalid. Please provide a valid key.")
        yield "Error: Cohere API Key is not set or is invalid. Please provide a valid key."
        return

    llm = OpenAI(temperature=1, model="gpt-4o-mini", api_key=openai_api_key)
    # NOTE(review): _get_client() is a private method of the LLM wrapper and
    # may change between library versions.
    logfire.instrument_openai(llm._get_client())

    # User-supplied text is passed as attributes rather than interpolated
    # into the message template, so braces in a query cannot be mistaken for
    # template fields.
    with logfire.span("Running query: {query}", query=query):
        _trim_memory_to_history(memory, history)

        current_messages = memory.get()
        logfire.info(
            "chat_history: {count} {messages}",
            count=len(current_messages),
            messages=current_messages,
        )
        logfire.info(
            "gradio_history: {count} {history}",
            count=len(history),
            history=history,
        )

        # Create agent
        tools = get_tools(db_collection="superprofe", cohere_api_key=cohere_api_key)
        agent = OpenAIAgent.from_tools(
            llm=llm,
            memory=memory,
            tools=tools,
            system_prompt=PROMPT_SYSTEM_MESSAGE,
        )

        # Generate answer: yield the text accumulated so far after each
        # token so the UI can render the reply as it grows.
        completion = agent.stream_chat(query)
        answer_str = ""
        for token in completion.response_gen:
            answer_str += token
            yield answer_str

        logging.info("Source count: %s", len(completion.sources))
        logging.info("Sources: %s", completion.sources)
164
+
165
def launch_ui():
    """Assemble the Gradio chat page and start serving it."""

    def new_memory():
        # Factory handed to gr.State so the initial value is built by calling
        # it rather than shared as one pre-built object.
        return ChatSummaryMemoryBuffer.from_defaults(token_limit=120000)

    with gr.Blocks(
        fill_height=True,
        title="Superprofes 🤖",
        analytics_enabled=True,
    ) as demo:
        session_memory = gr.State(new_memory)

        chat_window = gr.Chatbot(
            scale=1,
            placeholder="<strong>Superprofes 🤖: Encuentra al mejor profesor para tus necesidades</strong><br>",
            show_label=False,
            show_copy_button=True,
        )

        # The memory state is forwarded to generate_completion as its third
        # argument, after the message and the history.
        gr.ChatInterface(
            fn=generate_completion,
            chatbot=chat_window,
            additional_inputs=[session_memory],
        )

    demo.queue(default_concurrency_limit=64)
    demo.launch(debug=True, share=False)  # Set share=True to share the app online
193
+
194
if __name__ == "__main__":
    # Make sure the vector database is on disk before the UI needs it.
    download_knowledge_base_if_not_exists()

    # USE_GPU defaults to "1"; any other value selects CPU.
    Settings.use_gpu = os.getenv("USE_GPU", "1") == "1"
    if Settings.use_gpu:
        logging.info("Using GPU for inference.")
    else:
        logging.info("Using CPU for inference.")

    # Load the embedding model used for both indexing and retrieval.
    Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
    if Settings.embed_model is None:
        logging.error("Embedding model could not be loaded. Exiting.")
        # Raise SystemExit directly: the bare exit() helper is injected by
        # the site module and is not guaranteed to exist.
        raise SystemExit(1)

    # Launch the UI.
    launch_ui()
requirements.txt ADDED
Binary file (986 Bytes). View file