tenmenbot commited on
Commit
5c389b4
·
verified ·
1 Parent(s): de5cdf3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -9
app.py CHANGED
@@ -31,8 +31,8 @@ for fname in os.listdir(articles_dir):
31
  index = faiss.IndexFlatL2(384)
32
  index.add(np.array(vectors))
33
 
34
- # 要約モデルの準備
35
- summarizer = pipeline("text-generation", model="rinna/japanese-gpt2-medium", tokenizer="rinna/japanese-gpt2-medium")
36
 
37
  # チャットボット関数
38
  def chat(query):
@@ -42,15 +42,13 @@ def chat(query):
42
  retrieved_titles = [titles[i] for i in I[0]]
43
  retrieved_urls = [urls[i] for i in I[0]]
44
 
45
- context = "\n\n".join(retrieved_texts)
46
- prompt = f"以下の情報を参考にして、質問「{query}」に対する自然でわかりやすい日本語の回答を300文字以内で作成してください。\n\n{context}\n\n回答:"
47
 
48
- generated = summarizer(prompt, max_new_tokens=100, do_sample=True)[0]["generated_text"]
49
- answer = generated.split("回答:")[-1].strip()
50
 
51
- # 関連記事URLを表示
52
- links = "\n".join([f"🔗 [{titles[i]}]({urls[i]})" for i in range(len(retrieved_titles))])
53
- return f"{answer}\n\n参考記事:\n{links}"
54
 
55
  # Gradio UI
56
  gr.Interface(fn=chat, inputs="text", outputs="text", title="ブログ記事から回答する転職チャットボット").launch()
 
31
  index = faiss.IndexFlatL2(384)
32
  index.add(np.array(vectors))
33
 
34
+ # 要約モデル(ken11/japanese-summary-model)
35
+ summarizer = pipeline("summarization", model="ken11/japanese-summary-model")
36
 
37
  # チャットボット関数
38
  def chat(query):
 
42
  retrieved_titles = [titles[i] for i in I[0]]
43
  retrieved_urls = [urls[i] for i in I[0]]
44
 
45
+ context = "\n\n".join(retrieved_texts)[:1000] # BARTは長文に弱いので最大1000文字に制限
46
+ prompt = f"{context}\n\n質問:{query}\nこの情報をもとに簡潔に回答してください。"
47
 
48
+ summary = summarizer(prompt, max_length=128, min_length=30, do_sample=False)[0]["summary_text"]
 
49
 
50
+ links = "\n".join([f"🔗 [{retrieved_titles[i]}]({retrieved_urls[i]})" for i in range(len(retrieved_titles))])
51
+ return f"{summary}\n\n参考記事:\n{links}"
 
52
 
53
  # Gradio UI
54
  gr.Interface(fn=chat, inputs="text", outputs="text", title="ブログ記事から回答する転職チャットボット").launch()