Spaces:
Runtime error
Runtime error
| import threading # to allow streaming response | |
| import time # to pave the delivery of the message | |
| import datasets # for loading RAG database | |
| import faiss # to create a search index | |
| import gradio # for the interface | |
| import numpy # to work with vectors | |
| import sentence_transformers # to load an embedding model | |
| import spaces # for GPU | |
| import transformers # to load an LLM | |
# Opening assistant message shown in the chat window before the user types.
# The links are written in Markdown syntax.
GREETING = " ".join(
    [
        "Howdy! I'm an AI agent that uses [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation)",
        "to answer questions about research published at [ASME IDETC](https://asmedigitalcollection.asme.org/IDETC-CIE) within the last 10 years or so.",
        "I always try to cite my sources, but sometimes things get a little weird.",
        "What can I tell you about today?",
    ]
)
# One-click sample questions offered by the chat interface.
EXAMPLE_QUERIES = [
    "What's the difference between a markov chain and a hidden markov model?",
    "What can you tell me about analytical target cascading?",
    "What is known about different modes for human-AI teaming?",
    (
        "What are some examples of opportunistic versus restrictive design for additive manufacturing? "
        "Format your answer as a table with two columns (opportunistic, restrictive)."
    ),
]
# Model identifiers handed to the sentence_transformers / transformers loaders.
# Sentence-embedding model; used to encode user queries for retrieval.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# Causal chat LLM that writes the answer from the retrieved excerpts.
LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
# Retrieval corpus as a pandas DataFrame, one row per excerpt. Columns read in
# this file: "embedding", "title", "id" and "text".
data = datasets.load_dataset("ccm/rag-idetc")["train"].to_pandas()

# Sentence-embedding model used to encode incoming queries. It presumably must
# match the model that produced data["embedding"] -- TODO confirm.
embedding_model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Tokenizer and chat LLM used to generate answers.
tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
# Module-level streamer, kept so existing references to `streamer` still work.
# NOTE(review): one shared instance is not safe for concurrent generations
# (their tokens would interleave); prefer a fresh streamer per request.
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)
chat_model = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
)

# Exact inner-product index over L2-normalised excerpt vectors, i.e. cosine
# similarity search. IndexFlatIP states this directly, rather than building an
# IndexFlatL2 and overwriting its metric_type. Flat indexes need no training,
# so no train() call is made.
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
faiss.normalize_L2(vectors)  # normalises rows in place
excerpt_index = faiss.IndexFlatIP(vectors.shape[1])
# Vectors are added in DataFrame row order, so search ids are row positions.
excerpt_index.add(vectors)
def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Retrieve the k excerpts most similar to the query and build the LLM prompt.

    The query is embedded with ``embedding_model``, L2-normalised and searched
    against ``excerpt_index``. The hits are numbered (1..k) into the prompt
    template. A collapsible HTML/Markdown "References" section is also built,
    listing the quoted excerpts grouped under a DOI link for each paper.

    Args:
        query (str): The user's query
        k (int): The number of excerpts to retrieve; clamped to the number of
            indexed excerpts

    Returns:
        tuple[str, str]: The prompt for the LLM and the references section
    """
    # FAISS pads the result with id -1 when k exceeds the index size, and -1
    # must never be used as a row position.
    k = min(k, excerpt_index.ntotal)

    encoded_query = numpy.expand_dims(embedding_model.encode(query), axis=0)
    faiss.normalize_L2(encoded_query)
    _, indices = excerpt_index.search(encoded_query, k)
    # Index ids are row positions (vectors were added in row order).
    top_k = data.iloc[indices[0]]
    print(top_k["text"].values)

    prompt = (
        "You are an AI assistant who delights in helping people learn about research from the IDETC Conference. "
        "Your main task is to provide an ANSWER to the USER_QUERY based on the RESEARCH_EXCERPTS. "
        "Your ANSWER should be concise.\n\n"
        "RESEARCH_EXCERPTS:\n{{EXCERPTS_GO_HERE}}\n\n"
        "USER_QUERY:\n{{QUERY_GOES_HERE}}\n\n"
        "ANSWER:\n"
    )

    # Maps a Markdown "[Title](doi-url)" header to the excerpts from that paper,
    # preserving first-seen (i.e. best-ranked) order.
    references: dict[str, list[str]] = {}
    excerpt_lines = []
    rows = zip(top_k["title"].values, top_k["id"].values, top_k["text"].values)
    for rank, (title, doc_id, text) in enumerate(rows, start=1):
        excerpt_lines.append(f"{rank}. This excerpt is from: '{title}':\n{text}\n")
        url = f"https://doi.org/10.1115/{doc_id}"
        header = f"[{title.title()}]({url})\n"
        references.setdefault(header, []).append(text)
    research_excerpts = "".join(excerpt_lines)

    prompt = prompt.replace("{{EXCERPTS_GO_HERE}}", research_excerpts)
    prompt = prompt.replace("{{QUERY_GOES_HERE}}", query)
    print(references)

    # One "### header" per paper, followed by each excerpt as a blockquote.
    list_of_references = "\n".join(
        "### "
        + header
        + "".join(f'\n\n> "...{excerpt}..."' for excerpt in excerpts)
        for header, excerpts in references.items()
    )
    return (
        prompt,
        "\n\n<details><summary><h3>References</h3></summary>\n\n"
        + list_of_references
        + "\n\n</details>",
    )
def postprocess(response: str, bypass_from_preprocessing: str) -> str:
    """
    Finalise the LLM's answer before it is shown to the user.

    The text carried over from the preprocessing step is appended, unchanged,
    to the end of the response. In this app that text is presumably the
    references section.

    Args:
        response (str): Text generated by the LLM
        bypass_from_preprocessing (str): Text passed through from preprocessing

    Returns:
        str: The response followed by the bypassed text
    """
    combined = response + bypass_from_preprocessing
    return combined
# `spaces` is imported "for GPU" but was never applied. On ZeroGPU hardware a
# GPU is only attached while a @spaces.GPU-decorated function runs; the
# decorator is documented as a no-op elsewhere.
@spaces.GPU
def reply(message: str, history: list[str]) -> str:
    """
    Stream an answer to the user's message, grounded in retrieved excerpts.

    The message is turned into a retrieval-augmented prompt by ``preprocess``
    and appended to the prior conversation. The result is rendered with the
    tokenizer's chat template and generated by ``chat_model`` on a background
    thread, while this generator yields the growing answer.

    Args:
        message (str): The user's message
        history: The conversation history as [user, assistant] pairs; either
            slot may be None (e.g. the greeting has no user turn)

    Yields:
        str: The answer generated so far. The final value is passed through
            ``postprocess``, which appends the references section.

    Raises:
        Exception: Any error raised by ``chat_model.generate`` on the worker
            thread, re-raised here after streaming stops.
    """
    # The references are held back and appended once the answer is complete.
    prompt, references = preprocess(message, 10)

    # Convert Gradio's [user, assistant] pairs into chat-template messages,
    # skipping empty slots.
    conversation = [
        {"role": role, "content": content}
        for pair in history
        for role, content in zip(("user", "assistant"), pair)
        if content is not None
    ]
    conversation.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # Follow wherever device_map="auto" placed the model instead of assuming cuda:0.
    model_inputs = tokenizer([text], return_tensors="pt").to(chat_model.device)

    # A fresh streamer per request; a shared one would mix tokens from
    # concurrent generations.
    streamer = transformers.TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
    generation_errors: list[Exception] = []

    def _generate() -> None:
        # Worker-thread body. generate() only ends the stream when it finishes,
        # so on failure end it here or the consumer loop below blocks forever.
        try:
            chat_model.generate(**generate_kwargs)
        except Exception as err:  # re-raised on the caller's thread below
            generation_errors.append(err)
            streamer.end()

    worker = threading.Thread(target=_generate)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        # NOTE(review): chunks that are exactly "<" are dropped, presumably to
        # hide the start of leaked special tokens. This also drops a genuine
        # lone "<" -- confirm it is still needed with skip_special_tokens=True.
        if new_text != "<":
            partial_message += new_text
            time.sleep(0.05)  # pace the visible streaming
            yield partial_message
    worker.join()
    if generation_errors:
        raise generation_errors[0]

    yield postprocess(partial_message, references)
# Build the chat UI around `reply` and serve it. The IDETC logo fills the
# second avatar slot (the bot's, per Gradio's (user, bot) ordering), and the
# greeting is pre-loaded as the first assistant message.
idetc_logo_url = "https://event.asme.org/Events/media/library/images/IDETC-CIE/IDETC-Logo-Announcements.png?ext=.png"
chat_window = gradio.Chatbot(
    avatar_images=(None, idetc_logo_url),
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
    value=[[None, GREETING]],
    height="60vh",
    bubble_full_width=False,
)
demo = gradio.ChatInterface(
    reply,
    examples=EXAMPLE_QUERIES,
    chatbot=chat_window,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)
demo.launch(debug=True)