Nelly-43 commited on
Commit
fca1e2c
·
verified ·
1 Parent(s): e20e7ae

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +80 -80
inference.py CHANGED
@@ -1,81 +1,81 @@
1
- import os
2
- import torch
3
- import pandas as pd
4
- import transformers
5
- from pynvml import *
6
- import torch
7
- from langchain import hub
8
- from langchain_core.output_parsers import StrOutputParser
9
- from langchain_core.runnables import RunnablePassthrough
10
- from model_ret import load_model_and_pipeline
11
- from create_retriever import retriever_chroma
12
-
13
- # Model chain class
14
- class model_chain:
15
- model_name = ""
16
-
17
- def __init__(self,
18
- model_name_local,
19
- model_name_online="Llama",
20
- use_online=True,
21
- embedding_name="BAAI/bge-base-en-v1.5",
22
- splitter_type_dropdown="character",
23
- chunk_size_slider=512,
24
- chunk_overlap_slider=30,
25
- separator_textbox="\n",
26
- max_tokens_slider=2048) -> None:
27
- if os.path.exists(f"models//{model_name_local}") and len(os.listdir(f"models//{model_name_local}")):
28
- import gradio as gr
29
- gr.Info("Model *()* from online!!")
30
- self.model_name = model_name_local
31
- else:
32
- self.model_name = model_name_online
33
-
34
- self.tokenizer, self.model, self.llm = load_model_and_pipeline(self.model_name)
35
- # Creating the retriever
36
- # self.retriever = ensemble_retriever(embedding_name,
37
- # splitter_type=splitter_type_dropdown,
38
- # chunk_size=chunk_size_slider,
39
- # chunk_overlap=chunk_overlap_slider,
40
- # separator=separator_textbox,
41
- # max_tokens=max_tokens_slider)
42
- self.retriever = retriever_chroma(False, embedding_name, splitter_type_dropdown,
43
- chunk_size_slider, chunk_size_slider,
44
- separator_textbox, max_tokens_slider)
45
-
46
- # Defining the RAG chain
47
- prompt = hub.pull("rlm/rag-prompt")
48
- self.rag_chain = (
49
- {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
50
- | prompt
51
- | self.llm
52
- | StrOutputParser()
53
- )
54
-
55
- # Helper function to format documents
56
- def format_docs(self, docs):
57
- return "\n\n".join(doc.page_content for doc in docs)
58
-
59
- # Retrieve RAG chain
60
- def rag_chain_ret(self):
61
- return self.rag_chain
62
-
63
- # Answer retrieval function
64
- def ans_ret(self, inp):
65
- if self.model_name == 'Flant5':
66
- my_question = "What is KUET?"
67
- data = self.retriever.invoke(inp)
68
- context = ""
69
- for x in data[:2]:
70
- context += (x.page_content) + "\n"
71
- inputs = f"""Please answer to this question using this context:\n{context}\n{my_question}"""
72
- inputs = self.tokenizer(inputs, return_tensors="pt")
73
- outputs = self.model.generate(**inputs)
74
- answer = self.tokenizer.decode(outputs[0])
75
- from textwrap import fill
76
- ans = fill(answer, width=100)
77
- return ans
78
-
79
- ans = self.rag_chain.invoke(inp)
80
- ans = ans.split("Answer:")[1]
81
  return ans
 
1
+ import os
2
+ import torch
3
+ import pandas as pd
4
+ import transformers
5
+ from pynvml import *
6
+ import torch
7
+ from langchain import hub
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.runnables import RunnablePassthrough
10
+ from model_ret import load_model_and_pipeline
11
+ from create_retriever import retriever_chroma
12
+
13
+ # Model chain class
14
+ class model_chain:
15
+ model_name = ""
16
+
17
+ def __init__(self,
18
+ model_name_local,
19
+ model_name_online="Llama",
20
+ use_online=True,
21
+ embedding_name="sentence-transformers/all-mpnet-base-v2",
22
+ splitter_type_dropdown="character",
23
+ chunk_size_slider=512,
24
+ chunk_overlap_slider=30,
25
+ separator_textbox="\n",
26
+ max_tokens_slider=2048) -> None:
27
+ if os.path.exists(f"models//{model_name_local}") and len(os.listdir(f"models//{model_name_local}")):
28
+ import gradio as gr
29
+ gr.Info("Model *()* from online!!")
30
+ self.model_name = model_name_local
31
+ else:
32
+ self.model_name = model_name_online
33
+
34
+ self.tokenizer, self.model, self.llm = load_model_and_pipeline(self.model_name)
35
+ # Creating the retriever
36
+ # self.retriever = ensemble_retriever(embedding_name,
37
+ # splitter_type=splitter_type_dropdown,
38
+ # chunk_size=chunk_size_slider,
39
+ # chunk_overlap=chunk_overlap_slider,
40
+ # separator=separator_textbox,
41
+ # max_tokens=max_tokens_slider)
42
+ self.retriever = retriever_chroma(False, embedding_name, splitter_type_dropdown,
43
+ chunk_size_slider, chunk_size_slider,
44
+ separator_textbox, max_tokens_slider)
45
+
46
+ # Defining the RAG chain
47
+ prompt = hub.pull("rlm/rag-prompt")
48
+ self.rag_chain = (
49
+ {"context": self.retriever | self.format_docs, "question": RunnablePassthrough()}
50
+ | prompt
51
+ | self.llm
52
+ | StrOutputParser()
53
+ )
54
+
55
+ # Helper function to format documents
56
+ def format_docs(self, docs):
57
+ return "\n\n".join(doc.page_content for doc in docs)
58
+
59
+ # Retrieve RAG chain
60
+ def rag_chain_ret(self):
61
+ return self.rag_chain
62
+
63
+ # Answer retrieval function
64
+ def ans_ret(self, inp):
65
+ if self.model_name == 'Flant5':
66
+ my_question = "What is KUET?"
67
+ data = self.retriever.invoke(inp)
68
+ context = ""
69
+ for x in data[:2]:
70
+ context += (x.page_content) + "\n"
71
+ inputs = f"""Please answer to this question using this context:\n{context}\n{my_question}"""
72
+ inputs = self.tokenizer(inputs, return_tensors="pt")
73
+ outputs = self.model.generate(**inputs)
74
+ answer = self.tokenizer.decode(outputs[0])
75
+ from textwrap import fill
76
+ ans = fill(answer, width=100)
77
+ return ans
78
+
79
+ ans = self.rag_chain.invoke(inp)
80
+ ans = ans.split("Answer:")[1]
81
  return ans