ryoshimu commited on
Commit
de53089
·
1 Parent(s): faa1666
Files changed (4) hide show
  1. README.md +51 -0
  2. rag_system.py +103 -0
  3. requirements.txt +7 -0
  4. test.py +0 -1
README.md CHANGED
@@ -10,3 +10,54 @@ pinned: false
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
13
+
14
+ # RAG with Gemma
15
+
16
+ このプロジェクトは、Gemmaモデルを使用したRAG(Retrieval-Augmented Generation)システムの実装です。
17
+
18
+ ## 特徴
19
+
20
+ - Gemma-2b-itモデルを使用
21
+ - CPUで動作
22
+ - ChromaDBを使用したベクトルストア
23
+ - 日本語対応
24
+
25
+ ## セットアップ
26
+
27
+ 1. 必要なパッケージのインストール:
28
+ ```bash
29
+ pip install -r requirements.txt
30
+ ```
31
+
32
+ 2. モデルのダウンロード:
33
+ ```bash
34
+ python -c "from transformers import AutoTokenizer, AutoModelForCausalLM; AutoTokenizer.from_pretrained('google/gemma-2b-it'); AutoModelForCausalLM.from_pretrained('google/gemma-2b-it')"
35
+ ```
36
+
37
+ ## 使用方法
38
+
39
+ ```python
40
+ from rag_system import RAGSystem
41
+
42
+ # RAGシステムの初期化
43
+ rag = RAGSystem()
44
+
45
+ # ドキュメントの追加
46
+ documents = [
47
+ "ドキュメント1の内容",
48
+ "ドキュメント2の内容",
49
+ # ...
50
+ ]
51
+ rag.add_documents(documents)
52
+
53
+ # 質問と回答
54
+ question = "質問内容"
55
+ answer = rag.query(question)
56
+ print(f"回答: {answer}")
57
+ ```
58
+
59
+ ## 注意事項
60
+
61
+ - 初回実行時はモデルのダウンロードに時間がかかります
62
+ - CPUでの実行のため、生成に時間がかかる場合があります
63
+ - メモリ使用量に注意してください
rag_system.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ from sentence_transformers import SentenceTransformer
5
+ import chromadb
6
+ from chromadb.config import Settings
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ import torch
9
+
10
+ class RAGSystem:
11
+ def __init__(self, model_name: str = "google/gemma-2b-it"):
12
+ # トークナイザーとモデルの初期化
13
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ self.model = AutoModelForCausalLM.from_pretrained(
15
+ model_name,
16
+ torch_dtype=torch.float32,
17
+ device_map="cpu"
18
+ )
19
+
20
+ # 埋め込みモデルの初期化
21
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
22
+
23
+ # ChromaDBの初期化
24
+ self.chroma_client = chromadb.Client(Settings(
25
+ chroma_db_impl="duckdb+parquet",
26
+ persist_directory="db"
27
+ ))
28
+ self.collection = self.chroma_client.get_or_create_collection("documents")
29
+
30
+ # テキスト分割器の初期化
31
+ self.text_splitter = RecursiveCharacterTextSplitter(
32
+ chunk_size=500,
33
+ chunk_overlap=50
34
+ )
35
+
36
+ def add_documents(self, documents: List[str]):
37
+ """ドキュメントをベクトルストアに追加"""
38
+ chunks = []
39
+ for doc in documents:
40
+ chunks.extend(self.text_splitter.split_text(doc))
41
+
42
+ embeddings = self.embedding_model.encode(chunks)
43
+
44
+ # ChromaDBに保存
45
+ self.collection.add(
46
+ embeddings=embeddings.tolist(),
47
+ documents=chunks,
48
+ ids=[f"doc_{i}" for i in range(len(chunks))]
49
+ )
50
+
51
+ def query(self, question: str, k: int = 3) -> str:
52
+ """質問に対する回答を生成"""
53
+ # 質問の埋め込みを取得
54
+ query_embedding = self.embedding_model.encode(question)
55
+
56
+ # 関連ドキュメントを検索
57
+ results = self.collection.query(
58
+ query_embeddings=[query_embedding.tolist()],
59
+ n_results=k
60
+ )
61
+
62
+ # コンテキストの構築
63
+ context = "\n".join(results['documents'][0])
64
+
65
+ # プロンプトの構築
66
+ prompt = f"""以下のコンテキストに基づいて質問に答えてください。
67
+
68
+ コンテキスト:
69
+ {context}
70
+
71
+ 質問: {question}
72
+
73
+ 回答:"""
74
+
75
+ # 回答の生成
76
+ inputs = self.tokenizer(prompt, return_tensors="pt")
77
+ outputs = self.model.generate(
78
+ **inputs,
79
+ max_length=1000,
80
+ num_return_sequences=1,
81
+ temperature=0.7
82
+ )
83
+
84
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+
86
+ if __name__ == "__main__":
87
+ # 使用例
88
+ rag = RAGSystem()
89
+
90
+ # サンプルドキュメントの追加
91
+ documents = [
92
+ "RAG(Retrieval-Augmented Generation)は、大規模言語モデルに外部知識を組み込む手法です。",
93
+ "RAGは、検索と生成を組み合わせることで、より正確な回答を生成することができます。",
94
+ "RAGの主な利点は、モデルの知識を超えた情報を提供できることです。"
95
+ ]
96
+
97
+ rag.add_documents(documents)
98
+
99
+ # 質問の例
100
+ question = "RAGとは何ですか?"
101
+ answer = rag.query(question)
102
+ print(f"質問: {question}")
103
+ print(f"回答: {answer}")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers>=4.38.0
2
+ sentence-transformers>=2.2.2
3
+ faiss-cpu>=1.7.4
4
+ langchain>=0.1.0
5
+ chromadb>=0.4.22
6
+ tqdm>=4.66.1
7
+ python-dotenv>=1.0.0
test.py DELETED
@@ -1 +0,0 @@
1
- print("Hello, World!")