tenmenbot committed on
Commit
de5cdf3
·
verified ·
1 Parent(s): d8896e7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -9
app.py CHANGED
@@ -3,30 +3,54 @@ import os
3
  import numpy as np
4
  import faiss
5
  from sentence_transformers import SentenceTransformer
 
6
 
7
  # 記事フォルダ読み込み
8
  articles_dir = "articles"
9
- texts, vectors = [], []
10
 
11
  model = SentenceTransformer("all-MiniLM-L6-v2")
12
 
 
13
  for fname in os.listdir(articles_dir):
14
  with open(os.path.join(articles_dir, fname), "r", encoding="utf-8") as f:
15
- text = f.read()
16
- texts.append(text)
17
- vec = model.encode(text)
18
- vectors.append(vec)
 
 
 
 
 
 
 
 
 
19
 
20
  index = faiss.IndexFlatL2(384)
21
  index.add(np.array(vectors))
22
 
 
 
 
23
  # チャットボット関数
24
  def chat(query):
25
  vec = model.encode([query])
26
  _, I = index.search(np.array(vec), k=3)
27
- context = "\n---\n".join([texts[i] for i in I[0]])
28
- prompt = f"以下の情報を参考に質問に答えてください。\n\n{context}\n\n質問: {query}\n回答:"
29
- return prompt
 
 
 
 
 
 
 
 
 
 
30
 
31
  # Gradio UI
32
- gr.Interface(fn=chat, inputs="text", outputs="text", title="ブログ記事チャットボット").launch()
 
3
  import numpy as np
4
  import faiss
5
  from sentence_transformers import SentenceTransformer
6
+ from transformers import pipeline
7
 
8
  # 記事フォルダ読み込み
9
  articles_dir = "articles"
10
+ texts, titles, urls = [], [], []
11
 
12
  model = SentenceTransformer("all-MiniLM-L6-v2")
13
 
14
+ # 記事を読み込む
15
  for fname in os.listdir(articles_dir):
16
  with open(os.path.join(articles_dir, fname), "r", encoding="utf-8") as f:
17
+ content = f.read()
18
+ title_line = content.splitlines()[0].replace("タイトル:", "").strip()
19
+ url_line = content.splitlines()[1].replace("URL:", "").strip()
20
+ body_text = "\n".join(content.splitlines()[3:])
21
+ titles.append(title_line)
22
+ urls.append(url_line)
23
+ texts.append(body_text)
24
+
25
+ vec = model.encode(body_text)
26
+ if 'vectors' not in locals():
27
+ vectors = [vec]
28
+ else:
29
+ vectors.append(vec)
30
 
31
  index = faiss.IndexFlatL2(384)
32
  index.add(np.array(vectors))
33
 
34
+ # 要約モデルの準備
35
+ summarizer = pipeline("text-generation", model="rinna/japanese-gpt2-medium", tokenizer="rinna/japanese-gpt2-medium")
36
+
37
  # チャットボット関数
38
  def chat(query):
39
  vec = model.encode([query])
40
  _, I = index.search(np.array(vec), k=3)
41
+ retrieved_texts = [texts[i] for i in I[0]]
42
+ retrieved_titles = [titles[i] for i in I[0]]
43
+ retrieved_urls = [urls[i] for i in I[0]]
44
+
45
+ context = "\n\n".join(retrieved_texts)
46
+ prompt = f"以下の情報を参考にして、質問「{query}」に対する自然でわかりやすい日本語の回答を300文字以内で作成してください。\n\n{context}\n\n回答:"
47
+
48
+ generated = summarizer(prompt, max_new_tokens=100, do_sample=True)[0]["generated_text"]
49
+ answer = generated.split("回答:")[-1].strip()
50
+
51
+ # 関連記事URLを表示
52
+ links = "\n".join([f"🔗 [{titles[i]}]({urls[i]})" for i in range(len(retrieved_titles))])
53
+ return f"{answer}\n\n参考記事:\n{links}"
54
 
55
  # Gradio UI
56
+ gr.Interface(fn=chat, inputs="text", outputs="text", title="ブログ記事から回答する転職チャットボット").launch()