dazai555 commited on
Commit
c68fdcf
·
verified ·
1 Parent(s): c1ac3c6

Update utils/model.py

Browse files
Files changed (1) hide show
  1. utils/model.py +81 -80
utils/model.py CHANGED
@@ -1,81 +1,82 @@
1
- import os
2
- import re
3
- from langchain.embeddings import HuggingFaceEmbeddings
4
- from langchain.llms import HuggingFaceHub
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain.vectorstores import FAISS
7
- from langchain.document_loaders import TextLoader
8
- from langchain.chains import RetrievalQA
9
-
10
- llm = HuggingFaceHub(
11
- repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
12
- model_kwargs={
13
- "temperature": 0.75,
14
- "max_length": 500,
15
- }
16
- )
17
-
18
- def get_links():
19
- with open("data/links.txt", "r", encoding="utf-8") as file:
20
- data = file.read()
21
-
22
- lines = data.strip().split('\n')
23
-
24
- places_template = {}
25
- for line in lines:
26
- parts = re.split(r':\s*', line, maxsplit=1)
27
- if len(parts) == 2:
28
- place = parts[0].strip()
29
- link = parts[1].strip()
30
- places_template[place] = link
31
-
32
- return places_template
33
-
34
- def find_places_and_links(text, places):
35
- results = {}
36
-
37
- for place, link in places.items():
38
- pattern = re.compile(fr'\b{place}\b', flags=re.IGNORECASE)
39
- matches = pattern.findall(text)
40
-
41
- if matches:
42
- results[place] = link
43
-
44
- return results
45
-
46
- reviews_file_path = "data/data.txt"
47
- with open(reviews_file_path, "r", encoding="utf-8") as file:
48
- reviews = file.read().splitlines()
49
-
50
- loader = TextLoader(reviews_file_path)
51
- pages = loader.load_and_split()
52
-
53
- text_splitter = RecursiveCharacterTextSplitter(
54
- chunk_size=511,
55
- chunk_overlap=100,
56
- separators=['\n\n', '\n', '(?<=\. )', ' ', '']
57
- )
58
- docs = text_splitter.split_documents(pages)
59
-
60
- embeddings = HuggingFaceEmbeddings()
61
- vectorstore = FAISS.from_documents(docs, embeddings)
62
- retriever = vectorstore.as_retriever()
63
- chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
64
-
65
- def get_response_from_model(question: str):
66
- generated_response = chain({"query": question})
67
- question_index = generated_response["result"].find("Question:") or generated_response["result"].find("Answer:") or generated_response["result"].find("Helpful Answer:") or generated_response["result"].find("Recommended Restaurant:")
68
- if question_index != -1:
69
- answer = generated_response["result"][:question_index].strip()
70
- else:
71
- answer = generated_response["result"]
72
-
73
- places = get_links()
74
- links = find_places_and_links(answer, places)
75
- if links:
76
- output_list = [f'Location:']
77
-
78
- for place, link in links.items():
79
- output_list.append(f'{place}: {link}')
80
-
 
81
  return answer + '\n'.join(output_list)
 
1
+ import os
2
+ import re
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain.llms import HuggingFaceHub
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+ from langchain.vectorstores import FAISS
7
+ from langchain.document_loaders import TextLoader
8
+ from langchain.chains import RetrievalQA
9
+
10
# Hosted inference LLM used by the RetrievalQA chain built below.
# NOTE(review): HuggingFaceHub authenticates via the
# HUGGINGFACEHUB_API_TOKEN environment variable — presumably set in the
# deployment environment; confirm, since nothing in this file sets it.
llm = HuggingFaceHub(
    repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    model_kwargs={
        "temperature": 0.75,  # fairly creative sampling
        "max_length": 500,    # cap on generated sequence length
    }
)
17
+
18
def get_links():
    """Load the place-name -> URL mapping from ``data/links.txt``.

    Each line is expected to look like ``Place: URL``.  Lines that do
    not contain a colon separator are skipped silently.

    Returns:
        dict: place name mapped to its link, both stripped of
        surrounding whitespace.
    """
    with open("data/links.txt", "r", encoding="utf-8") as handle:
        raw = handle.read()

    mapping = {}
    for entry in raw.strip().split('\n'):
        # Split on the first colon (plus any following whitespace) only,
        # so URLs containing colons stay intact.
        pieces = re.split(r':\s*', entry, maxsplit=1)
        if len(pieces) != 2:
            continue  # malformed line — ignore it
        name, url = pieces[0].strip(), pieces[1].strip()
        mapping[name] = url

    return mapping
33
+
34
def find_places_and_links(text, places):
    """Return the subset of *places* whose names occur in *text*.

    Matching is case-insensitive and anchored on word boundaries.

    Bug fix vs. the original: the place name was interpolated into the
    regex unescaped, so names containing metacharacters (e.g. the "."
    in "St. Mark") were interpreted as regex syntax and could match
    unintended text.  The name is now escaped with ``re.escape`` so it
    matches literally.

    Args:
        text: Text to scan for place names.
        places: Mapping of place name -> link.

    Returns:
        dict: matched place names mapped to their links.
    """
    results = {}

    for place, link in places.items():
        # search() is sufficient — we only need to know the place is
        # mentioned, not how many times (the original collected every
        # match with findall() and then only tested truthiness).
        pattern = re.compile(rf'\b{re.escape(place)}\b', flags=re.IGNORECASE)
        if pattern.search(text):
            results[place] = link

    return results
45
+
46
# --- Module-level index construction (runs once, at import time) ---

# Read the raw review lines up front; this also fails fast if the data
# file is missing.
# NOTE(review): `reviews` is never referenced again in this file —
# possibly dead code; confirm against other callers before removing.
reviews_file_path = "data/data.txt"
with open(reviews_file_path, "r", encoding="utf-8") as file:
    reviews = file.read().splitlines()

# Re-load the same file through LangChain so splitting below works on
# Document objects rather than bare strings.
loader = TextLoader(reviews_file_path)
pages = loader.load_and_split()

# Chunk the documents for embedding: ~511-char chunks with 100 chars of
# overlap between neighbours.
# NOTE(review): '(?<=\. )' looks like a sentence-boundary regex, but
# RecursiveCharacterTextSplitter treats separators as literal strings
# unless is_separator_regex=True is passed — confirm intended behavior.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=511,
    chunk_overlap=100,
    separators=['\n\n', '\n', '(?<=\. )', ' ', '']
)
docs = text_splitter.split_documents(pages)

# Embed the chunks, build an in-memory FAISS index, and wire it into a
# "stuff"-type RetrievalQA chain around the hub-hosted LLM defined above.
embeddings = HuggingFaceEmbeddings()
vectorstore = FAISS.from_documents(docs, embeddings)
retriever = vectorstore.as_retriever()
chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)
64
+
65
def get_response_from_model(question: str):
    """Answer *question* via the RetrievalQA chain and append place links.

    The raw LLM output sometimes continues with a follow-up section
    introduced by a marker such as "Question:" or "Helpful Answer:";
    everything from the earliest such marker onward is stripped.  Links
    for any known place names mentioned in the answer are then appended
    under a "Location:" heading.

    Bug fixes vs. the original:
    - The marker lookup used ``a.find(x) or a.find(y) or ...``.
      ``str.find`` returns -1 (truthy) on a miss and 0 (falsy) at
      position 0, so the chain short-circuited on the first missing
      marker and never consulted the rest, and a marker at position 0
      fell through.  The earliest present marker is now used.
    - The "Location:" section was concatenated to the answer with no
      separating newline.

    Args:
        question: The user's question.

    Returns:
        str: the cleaned answer, optionally followed by a newline and a
        "Location:" block listing "place: link" lines.
    """
    generated_response = chain({"query": question})
    result = generated_response["result"]

    # Truncate at the earliest continuation marker, if any is present.
    markers = ("Question:", "Answer:", "Helpful Answer:", "Recommended Restaurant:")
    hits = [idx for idx in (result.find(m) for m in markers) if idx != -1]
    answer = result[:min(hits)].strip() if hits else result

    places = get_links()
    links = find_places_and_links(answer, places)

    output_list = []
    if links:
        output_list.append('Location:')
        for place, link in links.items():
            output_list.append(f'{place}: {link}')

    if not output_list:
        return answer
    return answer + '\n' + '\n'.join(output_list)