Spaces:
Build error
Build error
import json
import threading
import time
import typing

import faiss
import gradio
import numpy
import pandas
import sentence_transformers
import spaces
import transformers
# Constants

# First assistant message shown in the chat window (rendered as markdown).
GREETING = (
    "Howdy! "
    "I'm an AI agent that uses "
    "[retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) "
    "pipeline to answer questions about research by the "
    "[Design Research Collective](https://cmudrc.github.io/). "
    "And the best part is that I always try to cite my sources! "
    "I still make some mistakes though. "
    "What can I tell you about today?"
)

# Clickable sample questions offered by the chat interface.
EXAMPLE_QUERIES = [
    "Tell me about new research at the intersection of additive manufacturing and machine learning.",
    "What is a physics-informed neural network and what can it be used for?",
    "What can agent-based models do about climate change?",
    "What's the difference between a markov chain and a hidden markov model?",
    "What are the latest advancements in reinforcement learning?",
    "What is known about different modes for human-AI teaming?",
]

# Model identifiers: one for embedding queries, one for generating answers.
EMBEDDING_MODEL_NAME = "allenai-specter"
LLM_MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

# Number of publications placed into the prompt for each query.
PUBLICATIONS_TO_RETRIEVE = 5

# Location of the publications dataset (parquet file on the Hugging Face hub).
PARQUET_URL = "hf://datasets/ccm/publications/data/train-00000-of-00001.parquet"
# Load the dataset and convert to pandas
data = pandas.read_parquet(PARQUET_URL)

# Filter out any publications without an abstract.
# A record counts as "missing an abstract" when its JSON serialisation
# contains the literal text '"abstract": null'.
abstract_is_null = []
for bib_entry in data["bib_dict"].values:
    abstract_is_null.append('"abstract": null' in json.dumps(bib_entry))

has_abstract = ~pandas.Series(abstract_is_null)

# reset_index() keeps the old labels in an "index" column and renumbers the
# remaining rows from 0.
data = data[has_abstract].reset_index()
# Load the model for later use in embeddings
model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)

# Create an LLM pipeline that we can send queries to:
# a tokenizer, a streamer that exposes generated text as an iterator, and the
# causal language model itself.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    LLM_MODEL_NAME,
    trust_remote_code=True,
)
streamer = transformers.TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)
chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
    LLM_MODEL_NAME,
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)
# Create a FAISS index for fast similarity search
metric = faiss.METRIC_INNER_PRODUCT

# FAISS operates on C-contiguous float32 matrices, so coerce explicitly rather
# than relying on how the embeddings happened to be stored in the dataset.
vectors = numpy.ascontiguousarray(
    numpy.stack(data["embedding"].tolist(), axis=0),
    dtype=numpy.float32,
)

# Normalise every row in place to unit length; together with the
# inner-product metric this ranks results by cosine similarity.
faiss.normalize_L2(vectors)

# Build the flat index directly with the desired metric instead of creating an
# L2 index and overwriting its metric_type afterwards. A flat index needs no
# training step, so the vectors are simply added.
index = faiss.IndexFlat(vectors.shape[1], metric)
index.add(vectors)
def preprocess(query: str, k: int) -> tuple[str, str]:
    """
    Searches the dataset for the top k most relevant papers to the query and returns a prompt and references

    Args:
        query (str): The user's query
        k (int): The number of results to return

    Returns:
        tuple[str, str]: A tuple containing the prompt and references. The
            references are HTML links ("<a href=...>LastName Year</a>")
            joined with "; ".
    """
    # Embed the query and normalise it in place before searching the index.
    encoded_query = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(encoded_query)
    _, neighbors = index.search(encoded_query, k)

    # FAISS pads the result with -1 when fewer than k neighbours exist; drop
    # those so the lookup below cannot fail on a non-existent row.
    hit_positions = [int(position) for position in neighbors[0] if position >= 0]
    retrieved = data.loc[hit_positions]

    references = []
    bibtex_entries = []
    rows = zip(
        retrieved["bib_dict"].values,
        retrieved["bibtex"].values,
        retrieved["author_pub_id"].values,
    )
    for bib_dict, bibtex, author_pub_id in rows:
        year = str(int(bib_dict["pub_year"]))
        url = (
            "https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
            + author_pub_id
        )
        # Authors are separated by " and "; the last whitespace-delimited
        # token of the first author is used as the citation label.
        first_authors_last_name = bib_dict["author"].split(" and ")[0].split(" ")[-1]
        bibtex_entries.append(bibtex + "\n")
        references.append(f"<a href=\"{url}\">{first_authors_last_name} {year}</a>")

    research_abstracts = "".join(bibtex_entries)

    # Assemble the prompt by concatenation (not placeholder substitution) so
    # that text inside the abstracts or the query can never be mistaken for a
    # template marker. The section label matches the USER_QUERY name used in
    # the instructions.
    prompt = (
        "You are an AI assistant who delights in helping people learn about research from the Design Research Collective, which is a research lab at Carnegie Mellon University led by Professor Chris McComb. "
        "Your main task is to provide a concise ANSWER to the USER_QUERY that includes as many of the RESEARCH_ABSTRACTS as possible. "
        "The RESEARCH_ABSTRACTS are provided in the `.bibtex` format. Your ANSWER should contain citations to the RESEARCH_ABSTRACTS using (AUTHOR, YEAR) format. "
        "DO NOT list references at the end of the answer.\n\n"
        "RESEARCH_ABSTRACTS:\n```bibtex\n" + research_abstracts + "\n```\n\n"
        "USER_QUERY:\n" + query + "\n\n"
        "ANSWER:\n"
    )

    # Echo the final prompt to stdout for debugging.
    print(prompt)

    return prompt, "; ".join(references)
def reply(message: str, history: list[list]) -> typing.Iterator[str]:
    """
    This function is responsible for crafting a response

    Args:
        message (str): The user's message
        history (list[list]): The conversation history as
            [user_message, assistant_message] pairs; either entry may be None

    Yields:
        str: The response generated so far, and finally the complete response
            followed by a blank line and the reference links
    """
    # Apply preprocessing: build the retrieval-augmented prompt and the
    # reference links that are appended after generation finishes.
    prompt, bypass = preprocess(message, PUBLICATIONS_TO_RETRIEVE)

    # Flatten the [user, assistant] pairs into the role/content dictionaries
    # expected by the chat template, skipping empty slots, then add the prompt
    # as the newest user turn.
    history_transformer_format = [
        {"role": role, "content": content}
        for message_pair in history
        for role, content in zip(("user", "assistant"), message_pair)
        if content is not None
    ]
    history_transformer_format.append({"role": "user", "content": prompt})

    text = tokenizer.apply_chat_template(
        history_transformer_format, tokenize=False, add_generation_prompt=True
    )

    # Put the inputs on the device the model reports instead of assuming
    # "cuda:0"; the model is loaded with device_map="auto".
    model_inputs = tokenizer([text], return_tensors="pt").to(chatmodel.device)

    # Generation runs in a background thread and pushes decoded text into the
    # streamer, which is consumed below.
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
    t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
    t.start()

    partial_message = ""
    for new_token in streamer:
        # Chunks that are exactly "<" are skipped.
        if new_token != "<":
            partial_message += new_token
            time.sleep(0.01)
            yield partial_message

    # The streamer is exhausted, so generation is over; reap the worker thread.
    t.join()

    yield partial_message + "\n\n" + bypass
# Create and run the gradio interface

# Chat window, pre-populated with the greeting as an assistant-only turn.
chat_window = gradio.Chatbot(
    value=[[None, GREETING]],
    avatar_images=[
        "https://cdn.dribbble.com/users/316121/screenshots/2333676/11-04_scotty-plaid_dribbble.png",
        "https://media.thetab.com/blogs.dir/90/files/2021/06/screenshot-2021-06-10-at-110730-1024x537.png",
    ],
    height="60vh",
    bubble_full_width=False,
    show_label=False,
    show_share_button=False,
    show_copy_button=False,
)

# Font used by the interface theme.
interface_theme = gradio.themes.Default(
    font=[gradio.themes.GoogleFont("Zilla Slab")]
)

# Wire the reply generator into a chat UI with no retry/undo/clear buttons.
interface = gradio.ChatInterface(
    reply,
    chatbot=chat_window,
    examples=EXAMPLE_QUERIES,
    theme=interface_theme,
    retry_btn=None,
    undo_btn=None,
    clear_btn=None,
)

interface.launch(debug=True)