Spaces:
Sleeping
Sleeping
| import time | |
| import random | |
| import ujson as json | |
| from typing import List | |
| from dataclasses import dataclass | |
| import gradio as gr | |
@dataclass
class Paper:
    """A single S2ORC paper record parsed from one JSONL line.

    Bug fix: the ``@dataclass`` decorator was missing, so the class had
    annotations but no generated ``__init__`` — ``Paper(paper_id=..., ...)``
    in ``load_database`` raised ``TypeError``.
    """
    paper_id: str
    title: str
    abstract: str
    authors: List[str] = None  # load_database passes [] explicitly; None kept as default for compatibility
    year: int = None  # NOTE(review): effectively Optional[int]
    doi: str = None   # NOTE(review): effectively Optional[str]
def load_database(filename):
    """Read a JSONL file of paper records and return a list of Paper objects.

    Every line of *filename* must be a JSON object with at least the keys
    ``paper_id``, ``title`` and ``abstract``; ``authors``, ``year`` and
    ``doi`` are optional and default to ``[]`` / ``None`` / ``None``.
    """
    papers = []
    with open(filename, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            record = json.loads(raw_line)
            papers.append(
                Paper(
                    paper_id=record["paper_id"],
                    title=record["title"],
                    abstract=record["abstract"],
                    authors=record.get("authors", []),
                    year=record.get("year"),
                    doi=record.get("doi"),
                )
            )
    return papers
class S2ORCRAGPipeline:
    """Toy RAG pipeline over an S2ORC-style JSONL paper database.

    Both stages are currently placeholders ("Fake"): retrieval is a
    deterministic pseudo-random sample of the database, and generation
    formats a template string and passes it through ``model``.
    """

    def __init__(
        self,
        s2orc_filename,
        model=lambda x: x,
    ):
        """Load the paper database and store the response post-processor.

        Args:
            s2orc_filename: Path to a JSONL file readable by load_database().
            model: Callable applied to the generated response text
                (identity by default; intended to become an LLM call).
        """
        self.s2orc_filename = s2orc_filename
        self.database = load_database(s2orc_filename)
        self.model = model

    def retrieve_top_k(
        self,
        query: str,
        topk=5,
    ):
        """Return up to ``topk`` papers "relevant" to ``query``.

        Fake implementation: a pseudo-random sample, seeded from the query
        length so that identical (query, topk) pairs are reproducible.
        """
        # Fake implementation
        random.seed(len(query) + topk)
        # Bug fix: random.sample raises ValueError when the requested sample
        # size exceeds the population; cap at the database size instead of
        # crashing on small databases.
        return random.sample(self.database, min(topk, len(self.database)))
        # Real implementation
        # TODO: DB-team

    def generate_response(
        self,
        query,
        retrieved_papers,
    ):
        """Build the chat answer for ``query`` from ``retrieved_papers``.

        Fake implementation: formats the papers' titles/abstracts into a
        fixed (Japanese) template, then pipes the text through ``self.model``.
        """
        # Fake implementation
        response = f"{query}... わかった!こちらはあなたの質問に関連する論文です:\n"
        for paper in retrieved_papers:
            response += f"- {paper.title}: {paper.abstract}\n"
        response += "\nどう思いますか?\n"
        response = self.model(response)
        return response
        # Real implementation
        # TODO: Generation-team

    def __call__(
        self,
        query,
    ):
        """Run the full pipeline: retrieve papers, then generate a reply."""
        # Firstly, retrieve papers from the database
        retrieved_papers = self.retrieve_top_k(query, topk=3)
        # Secondly, generate a response based on the query and the retrieved papers
        response = self.generate_response(query, retrieved_papers)
        return response

    def slow_echo(self, message, history):
        """Gradio streaming handler: yield the reply one character at a time.

        ``history`` is required by the ChatInterface callback signature but
        is unused here.
        """
        output = self.__call__(query=message)
        for i in range(len(output)):
            time.sleep(0.001)  # small delay to simulate token streaming
            yield output[: i + 1]
if __name__ == "__main__":
    # Build the pipeline from a sample S2ORC-style JSONL file; the model is
    # an identity function until a real generator is plugged in.
    example_filename = "sample.jsonl"
    pipeline = S2ORCRAGPipeline(
        s2orc_filename=example_filename,
        model=lambda x: x,
    )

    initial_messages = [{"role": "assistant", "content": "こんにちは〜今日は何の論文を探したいですか?"}]

    chatbot = gr.Chatbot(
        value=initial_messages,
        type="messages",
        resizable=True,
        height=700,
        placeholder="こんにちは〜今日は何の論文を探したいですか?",
    )

    demo = gr.ChatInterface(
        pipeline.slow_echo,
        chatbot=chatbot,
        type="messages",
        flagging_mode="manual",
        flagging_options=["Like", "Spam", "Inappropriate", "Other"],
        title="LLMC S2ORC 論文検索 (+RAG)",
        description="",
        save_history=True,
        examples=[
            "こんにちは",
            "LLM関連の論文を探したい",
            "Find Suzuki's papers on graphene from 2019 to 2021 in Surface Science Journal.",
        ],
    )
    # NOTE: share=True is known to fail when running on the NII network.
    demo.launch(debug=True, share=True)