Spaces:
Build error
Build error
| # Required imports | |
| import json | |
| import time | |
| import os | |
| from sentence_transformers import SentenceTransformer | |
| from pinecone import Pinecone, ServerlessSpec | |
| from groq import Groq | |
| from tqdm.auto import tqdm | |
| import streamlit as st | |
| import re | |
# Variables
FILE_PATH = "anjibot_chunks.json"
BATCH_SIZE = 384
INDEX_NAME = "groq-llama-3-rag"
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
DIMS = 768


# Streamlit re-executes this whole script on every user interaction. Without
# caching, every chat message would reload the embedding model, re-read the
# corpus and re-embed + re-upsert all of it. st.cache_resource makes each of
# these steps run once per server process instead.
@st.cache_resource
def _load_encoder() -> SentenceTransformer:
    """Load the sentence-embedding model used for indexing and for queries."""
    return SentenceTransformer('dwzhu/e5-base-4k')


@st.cache_resource
def _load_chunks(path: str) -> dict:
    """Read the chunk corpus from the JSON file at *path*.

    Code in this file expects an ``"id"`` list and a ``"metadata"`` list of
    dicts with ``"title"`` and ``"content"``; every value is sliced with the
    same batch window during indexing.
    """
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


@st.cache_resource
def _build_index(_pc, _spec, _encoder, _data, index_name: str, dims: int, batch_size: int):
    """Ensure *index_name* exists in Pinecone and upsert an embedding per chunk.

    Parameters prefixed with ``_`` are excluded from Streamlit's cache hashing
    (client/model objects are not hashable); the cache key is the index name,
    dimension and batch size.

    Returns:
        The Pinecone index handle.

    Raises:
        RuntimeError: if the encoder returns a different number of vectors
            than the number of chunks it was given.
    """
    existing_indexes = [index_info["name"] for index_info in _pc.list_indexes()]
    if index_name not in existing_indexes:
        _pc.create_index(index_name, dimension=dims, metric='cosine', spec=_spec)
        # Wait for the index to be initialized
        while not _pc.describe_index(index_name).status['ready']:
            time.sleep(1)

    idx = _pc.Index(index_name)
    total = len(_data['id'])
    for start in tqdm(range(0, total, batch_size)):
        end = min(total, start + batch_size)
        # Slice every list in the corpus to the same window.
        batch = {k: v[start:end] for k, v in _data.items()}
        chunks = [f'{x["title"]}: {x["content"]}' for x in batch["metadata"]]
        embeds = _encoder.encode(chunks)
        # Explicit check rather than `assert`, which `python -O` strips out.
        if len(embeds) != end - start:
            raise RuntimeError(
                f"Encoder returned {len(embeds)} vectors for {end - start} chunks"
            )
        # .tolist() turns the numpy rows into plain float lists for the client.
        idx.upsert(vectors=list(zip(batch["id"], embeds.tolist(), batch["metadata"])))
    return idx


encoder = _load_encoder()
groq_client = Groq(api_key=GROQ_API_KEY)
data = _load_chunks(FILE_PATH)
pc = Pinecone(api_key=PINECONE_API_KEY)
spec = ServerlessSpec(cloud="aws", region='us-east-1')
index = _build_index(pc, spec, encoder, data, INDEX_NAME, DIMS, BATCH_SIZE)
def extract_course_code(text: str) -> "list[str] | None":
    """Return every course code mentioned in *text*, or None if there are none.

    A course code is one of the prefixes GED(S), STAT(S), MATH(S), COSC, SENG
    or ITGY (case-insensitive), optional whitespace, then exactly three digits,
    e.g. "COSC 401" or "stat301". Matches are returned verbatim, in the order
    they appear in *text*.
    """
    pattern = r'\b(?:geds?|stats?|maths?|cosc|seng|itgy)\s*\d{3}\b'
    match = re.findall(pattern, text, re.IGNORECASE)
    # None (not an empty list) when nothing matched -- callers test truthiness.
    return match if match else None
def get_docs(query: str, top_k: int) -> list[str]:
    """Return up to *top_k* context passages for *query*.

    Passages whose content mentions a course code found in the query come
    first, in corpus order. Any remaining slots are filled from a Pinecone
    similarity search, skipping passages that were already selected.

    Args:
        query: the user's question.
        top_k: maximum number of passages to return; non-positive gives [].
    """
    if top_k <= 0:
        return []

    docs: list[str] = []
    course_codes = extract_course_code(query)
    if course_codes:
        # Compare with all whitespace removed so "cosc401" / "COSC  401" in the
        # query still find "COSC 401" in the corpus (and vice versa).
        wanted = {re.sub(r'\s+', '', code.lower()) for code in course_codes}
        for chunk in data['metadata']:
            content = chunk['content']
            squashed = re.sub(r'\s+', '', content.lower())
            if any(code in squashed for code in wanted):
                docs.append(content)

    if len(docs) < top_k:
        xq = encoder.encode(query)
        # Request top_k (not just the free slots) so hits that duplicate an
        # exact match can be dropped without leaving slots empty.
        res = index.query(vector=xq.tolist(), top_k=top_k, include_metadata=True)
        seen = set(docs)
        for match in res["matches"]:
            content = match["metadata"]['content']
            if content not in seen:
                seen.add(content)
                docs.append(content)

    return docs[:top_k]
def get_response(query: str, docs: list[str], *, model: str = "llama3-70b-8192") -> str:
    """Ask the Groq chat model to answer *query* using *docs* as context.

    Args:
        query: the user's question, sent as the user message.
        docs: retrieved passages, appended to the system prompt separated by
            "---" lines.
        model: Groq model ID. NOTE(review): confirm this ID is still served.

    Returns:
        The assistant's reply text, or "" if the API returned no content.
    """
    # Build the context separately. Adjacent string literals are concatenated
    # at compile time, so the old `"CONTEXT:\n" "\n---\n".join(docs)` inside the
    # parenthesised prompt called .join() on the *entire* prompt: docs were
    # separated by the whole instruction text, and with 0 or 1 docs the
    # instructions vanished completely.
    context = "\n---\n".join(docs)
    system_message = (
        "You are Anjibot, the AI course rep of 400 Level Computer Science department. You are always helpful, jovial, can be sarcastic but still sweet.\n"
        "Provide the answer to class-related queries using\n"
        "context provided below.\n"
        "If you don't know the answer to the user's question based on your pretrained knowledge and the context provided, just direct the user to Anji the human course rep.\n"
        "Anji's phone number: 08145170886.\n\n"
        "CONTEXT:\n"
        + context
    )
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": query},
    ]
    chat_response = groq_client.chat.completions.create(
        model=model,
        messages=messages,
    )
    # message.content can be None; callers split the result as a string.
    return chat_response.choices[0].message.content or ""
def handle_query(user_query: str):
    """Answer *user_query*, yielding the reply in small pieces for st.write_stream.

    The full reply is obtained first (top-5 retrieval, then one model call).
    It is then re-emitted word by word with a 0.05 s pause between pieces to
    give a typing effect.
    """
    docs = get_docs(user_query, top_k=5)
    response = get_response(user_query, docs=docs)
    # Each piece is a word plus the whitespace around it, so concatenating the
    # pieces reproduces the reply exactly. Splitting on whitespace and
    # re-joining with " " flattened newlines, breaking Markdown lists and
    # paragraphs in the rendered answer and in the saved chat history.
    for piece in re.findall(r'\s*\S+\s*', response):
        yield piece
        time.sleep(0.05)
def main():
    """Render the chat page: replay the stored conversation, then answer any new prompt."""
    st.title("Ask Anjibot 2.0")

    # st.session_state survives Streamlit reruns; create the history list once.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for past in st.session_state.messages:
        speaker = past["role"]
        with st.chat_message(speaker):
            st.markdown(past["content"])

    prompt = st.chat_input("Ask me anything")
    if not prompt:
        # Nothing submitted on this rerun.
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        reply = st.write_stream(handle_query(prompt))
    st.session_state.messages.append({"role": "assistant", "content": reply})


if __name__ == "__main__":
    main()