interview-chat / app.py
tenmenbot's picture
Upload app.py
7eb64ea verified
import gradio as gr
import os
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration
# Load the article folder
articles_dir = "articles"
texts, titles, urls = [], [], []
vectors = []
model = SentenceTransformer("all-MiniLM-L6-v2")

# Read the articles.
# Layout read from each file: line 1 holds the title, line 2 holds the URL,
# line 3 is skipped, and everything from line 4 onwards is the body.
for fname in os.listdir(articles_dir):
    article_path = os.path.join(articles_dir, fname)
    if not os.path.isfile(article_path):
        # Sub-directories are not articles; open() would fail on them.
        continue
    with open(article_path, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    if len(lines) < 2:
        # Fail with the offending filename instead of a bare IndexError.
        raise ValueError(
            f"Article {article_path!r} must have a title line and a URL line"
        )
    title_line = lines[0].replace("γ‚Ώγ‚€γƒˆγƒ«οΌš", "").strip()
    url_line = lines[1].replace("URL:", "").strip()
    body_text = "\n".join(lines[3:])
    titles.append(title_line)
    urls.append(url_line)
    texts.append(body_text)
    # The three metadata lists and `vectors` stay index-aligned: position i
    # in the FAISS index corresponds to texts[i], titles[i] and urls[i].
    vectors.append(model.encode(body_text))

if not vectors:
    # Without this guard an empty folder surfaced as an unrelated NameError.
    raise RuntimeError(f"No articles found in {articles_dir!r}")

# FAISS works on float32 matrices; take the dimension from the embeddings
# themselves so it cannot drift from the chosen encoder.
embeddings = np.asarray(vectors, dtype=np.float32)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
# T5 summarization model.
# The tokenizer and the model are loaded from the same checkpoint name.
T5_CHECKPOINT = "sonoisa/t5-base-japanese"
tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT)
def generate_summary(text):
    """Return a summary of *text* generated by the module-level T5 model.

    Newlines in *text* are replaced by spaces and the result is prefixed
    with ``"summarize: "``. The prompt is truncated to 512 tokens, and the
    output is decoded deterministically (no sampling) with a length between
    32 and 128 tokens.

    NOTE(review): presumably the loaded checkpoint has been fine-tuned to
    understand the "summarize: " prefix -- confirm against its model card.
    """
    prompt = "summarize: " + " ".join(text.split("\n"))
    encoded = tokenizer.encode(
        prompt,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = t5_model.generate(
        encoded,
        max_length=128,
        min_length=32,
        do_sample=False,
    )
    # generate() returns a batch; only one prompt was submitted.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Chatbot function
def chat(query):
    """Answer *query* from the indexed articles.

    The query is embedded with the module-level sentence encoder and the
    three nearest articles are looked up in the FAISS index. Their bodies
    are joined, cut to the first 1000 characters, and summarized.

    Args:
        query: The user's question as plain text.

    Returns:
        The generated summary followed by a reference heading and one
        Markdown link line per retrieved article.
    """
    query_vec = np.asarray(model.encode([query]), dtype=np.float32)
    _, neighbours = index.search(query_vec, k=3)

    # FAISS pads the result with -1 when the index holds fewer than k
    # vectors. A negative index would silently select the *last* article,
    # so drop the padding instead of indexing with it.
    hits = [int(i) for i in neighbours[0] if i >= 0]

    context = "\n\n".join(texts[i] for i in hits)[:1000]
    summary = generate_summary(context)
    links = "\n".join(f"πŸ”— [{titles[i]}]({urls[i]})" for i in hits)
    return f"{summary}\n\nε‚θ€ƒθ¨˜δΊ‹οΌš\n{links}"
# Gradio UI: a single text box in, a single text box out, served by chat().
demo = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    title="γƒ–γƒ­γ‚°θ¨˜δΊ‹γ‹γ‚‰ε›žη­”γ™γ‚‹θ»’θ·γƒγƒ£γƒƒγƒˆγƒœγƒƒγƒˆ",
)
demo.launch()