Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,8 +6,8 @@ import os
|
|
| 6 |
from threading import Thread
|
| 7 |
import random
|
| 8 |
from datasets import load_dataset
|
| 9 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 10 |
import numpy as np
|
|
|
|
| 11 |
|
| 12 |
# GPU 메모리 관리
|
| 13 |
torch.cuda.empty_cache()
|
|
@@ -29,40 +29,19 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
| 29 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
| 30 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
embeddings = hidden_states.mean(dim=1)
|
| 39 |
-
return embeddings
|
| 40 |
-
|
| 41 |
-
# 데이터셋의 질문들을 임베딩
|
| 42 |
-
print("임베딩 생성 시작...")
|
| 43 |
-
questions = wiki_dataset['train']['question'][:1000] # 처음 1000개만 사용 (테스트용)
|
| 44 |
-
question_embeddings = []
|
| 45 |
-
batch_size = 8 # 배치 사이즈 줄임
|
| 46 |
-
|
| 47 |
-
for i in range(0, len(questions), batch_size):
|
| 48 |
-
batch = questions[i:i+batch_size]
|
| 49 |
-
batch_embeddings = get_embeddings(batch, model, tokenizer)
|
| 50 |
-
question_embeddings.append(batch_embeddings.cpu())
|
| 51 |
-
if i % 100 == 0:
|
| 52 |
-
print(f"Processed {i}/{len(questions)} questions")
|
| 53 |
-
|
| 54 |
-
question_embeddings = torch.cat(question_embeddings, dim=0)
|
| 55 |
-
print("임베딩 생성 완료")
|
| 56 |
|
| 57 |
def find_relevant_context(query, top_k=3):
|
| 58 |
-
# 쿼리
|
| 59 |
-
|
| 60 |
|
| 61 |
# 코사인 유사도 계산
|
| 62 |
-
similarities =
|
| 63 |
-
query_embedding.cpu().numpy(),
|
| 64 |
-
question_embeddings.numpy()
|
| 65 |
-
)[0]
|
| 66 |
|
| 67 |
# 가장 유사한 질문들의 인덱스
|
| 68 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
|
@@ -70,11 +49,12 @@ def find_relevant_context(query, top_k=3):
|
|
| 70 |
# 관련 컨텍스트 추출
|
| 71 |
relevant_contexts = []
|
| 72 |
for idx in top_indices:
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
return relevant_contexts
|
| 80 |
|
|
@@ -83,11 +63,11 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
| 83 |
print(f'message is - {message}')
|
| 84 |
print(f'history is - {history}')
|
| 85 |
|
| 86 |
-
#
|
| 87 |
relevant_contexts = find_relevant_context(message)
|
| 88 |
context_prompt = "\n\n관련 참고 정보:\n"
|
| 89 |
for ctx in relevant_contexts:
|
| 90 |
-
context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n\n"
|
| 91 |
|
| 92 |
# 대화 히스토리 구성
|
| 93 |
conversation = []
|
|
@@ -97,6 +77,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
|
|
| 97 |
{"role": "assistant", "content": answer}
|
| 98 |
])
|
| 99 |
|
|
|
|
| 100 |
# 컨텍스트를 포함한 최종 프롬프트 구성
|
| 101 |
final_message = context_prompt + "\n현재 질문: " + message
|
| 102 |
conversation.append({"role": "user", "content": final_message})
|
|
|
|
| 6 |
from threading import Thread
|
| 7 |
import random
|
| 8 |
from datasets import load_dataset
|
|
|
|
| 9 |
import numpy as np
|
| 10 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 11 |
|
| 12 |
# GPU 메모리 관리
|
| 13 |
torch.cuda.empty_cache()
|
|
|
|
| 29 |
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
|
| 30 |
print("Wikipedia dataset loaded:", wiki_dataset)
|
| 31 |
|
| 32 |
+
# TF-IDF 벡터라이저 초기화 및 학습
|
| 33 |
+
print("TF-IDF 벡터화 시작...")
|
| 34 |
+
questions = wiki_dataset['train']['question'][:10000] # 처음 10000개만 사용
|
| 35 |
+
vectorizer = TfidfVectorizer(max_features=1000)
|
| 36 |
+
question_vectors = vectorizer.fit_transform(questions)
|
| 37 |
+
print("TF-IDF 벡터화 완료")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def find_relevant_context(query, top_k=3):
|
| 40 |
+
# 쿼리 벡터화
|
| 41 |
+
query_vector = vectorizer.transform([query])
|
| 42 |
|
| 43 |
# 코사인 유사도 계산
|
| 44 |
+
similarities = (query_vector * question_vectors.T).toarray()[0]
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# 가장 유사한 질문들의 인덱스
|
| 47 |
top_indices = np.argsort(similarities)[-top_k:][::-1]
|
|
|
|
| 49 |
# 관련 컨텍스트 추출
|
| 50 |
relevant_contexts = []
|
| 51 |
for idx in top_indices:
|
| 52 |
+
if similarities[idx] > 0: # 유사도가 0보다 큰 경우만 포함
|
| 53 |
+
relevant_contexts.append({
|
| 54 |
+
'question': questions[idx],
|
| 55 |
+
'answer': wiki_dataset['train']['answer'][idx],
|
| 56 |
+
'similarity': similarities[idx]
|
| 57 |
+
})
|
| 58 |
|
| 59 |
return relevant_contexts
|
| 60 |
|
|
|
|
| 63 |
print(f'message is - {message}')
|
| 64 |
print(f'history is - {history}')
|
| 65 |
|
| 66 |
+
# 관련 컨텍스트 찾기
|
| 67 |
relevant_contexts = find_relevant_context(message)
|
| 68 |
context_prompt = "\n\n관련 참고 정보:\n"
|
| 69 |
for ctx in relevant_contexts:
|
| 70 |
+
context_prompt += f"Q: {ctx['question']}\nA: {ctx['answer']}\n유사도: {ctx['similarity']:.3f}\n\n"
|
| 71 |
|
| 72 |
# 대화 히스토리 구성
|
| 73 |
conversation = []
|
|
|
|
| 77 |
{"role": "assistant", "content": answer}
|
| 78 |
])
|
| 79 |
|
| 80 |
+
|
| 81 |
# 컨텍스트를 포함한 최종 프롬프트 구성
|
| 82 |
final_message = context_prompt + "\n현재 질문: " + message
|
| 83 |
conversation.append({"role": "user", "content": final_message})
|