bluebear27 commited on
Commit
bdf2590
·
1 Parent(s): 24cdfe9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -3
app.py CHANGED
@@ -12,9 +12,84 @@ from langchain.vectorstores import FAISS
12
  from langchain.chains import ConversationalRetrievalChain
13
  from transformers import pipeline
14
 
15
- pipe = pipeline('sentiment-analysis')
16
- text = st.text_area('enter some text')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  if text:
19
- out = pipe(text)
20
  st.json(out)
 
12
  from langchain.chains import ConversationalRetrievalChain
13
  from transformers import pipeline
14
 
15
import os

# ---------------------------------------------------------------------------
# LLM setup: Llama-2-7b-chat, 4-bit quantized so it fits on a single GPU.
# NOTE(review): the imports this script relies on (streamlit as st, torch,
# torch.cuda as cuda, transformers, bfloat16, StoppingCriteria,
# StoppingCriteriaList, HuggingFacePipeline, WebBaseLoader,
# RecursiveCharacterTextSplitter, HuggingFaceEmbeddings, FAISS,
# ConversationalRetrievalChain) live in the file header outside this span.
# ---------------------------------------------------------------------------
model_id = 'meta-llama/Llama-2-7b-chat-hf'
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

# NF4 double-quantized 4-bit weights with bf16 compute.
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16,
)

# SECURITY FIX: the original commit hard-coded a live Hugging Face access
# token here.  That token is now public and MUST be revoked on
# huggingface.co; the replacement is read from the environment (Spaces:
# set HF_TOKEN as a repository secret).
hf_auth = os.environ.get('HF_TOKEN')

model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    use_auth_token=hf_auth,
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    use_auth_token=hf_auth,
)
model.eval()  # inference only — disable dropout etc.

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    use_auth_token=hf_auth,
)

# Sequences that mark the end of the assistant's turn; generation halts as
# soon as the output ends with any of them.
stop_list = ['\nHuman:', '\n```\n']
stop_token_ids = [
    torch.LongTensor(tokenizer(s)['input_ids']).to(device) for s in stop_list
]


class StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with any stop sequence."""

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs) -> bool:
        return any(
            torch.eq(input_ids[0][-len(ids):], ids).all()
            for ids in stop_token_ids
        )


stopping_criteria = StoppingCriteriaList([StopOnTokens()])

generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    task='text-generation',
    # model generation parameters:
    stopping_criteria=stopping_criteria,  # without this model rambles during chat
    temperature=0.1,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
    max_new_tokens=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1,  # without this output begins repeating
)

llm = HuggingFacePipeline(pipeline=generate_text)

# ---------------------------------------------------------------------------
# Retrieval index: fetch the lecture page, chunk it, embed it, store in FAISS.
# ---------------------------------------------------------------------------
web_links = ["https://stanford-cs324.github.io/winter2022/lectures/introduction/"]
loader = WebBaseLoader(web_links)
documents = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
all_splits = text_splitter.split_documents(documents)

model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": "cuda"}
embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
vectorstore = FAISS.from_documents(all_splits, embeddings)

chain = ConversationalRetrievalChain.from_llm(
    llm, vectorstore.as_retriever(), return_source_documents=True
)

# BUG FIX: `chat_history = []` at module level is re-executed on every
# Streamlit rerun (each widget interaction), so the chain never saw any
# history and the "conversation" had no memory.  Persist it in
# st.session_state and append each (question, answer) turn.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

text = st.text_area('Enter your query: ')

if text:
    out = chain({
        "question": text,
        "chat_history": st.session_state.chat_history,
    })
    st.session_state.chat_history.append((text, out["answer"]))
    st.json(out)