kinely committed on
Commit
505df3c
·
verified ·
1 Parent(s): 3a6d9c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -13
app.py CHANGED
@@ -1,44 +1,48 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5ForConditionalGeneration
 
3
  import faiss
4
  import numpy as np
5
 
6
- # Load model and tokenizer for sentence transformers
7
- tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
8
- model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
9
 
10
- # Prepare dataset (Wikipedia dataset can be used)
11
  corpus = ["Article text 1", "Article text 2", "Article text 3"]
12
 
13
- # Tokenize and encode
14
- encoded_texts = [model(**tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)).last_hidden_state.mean(1).detach().numpy() for text in corpus]
15
 
16
  # Create FAISS index
17
- dimension = encoded_texts[0].shape[1]
18
  index = faiss.IndexFlatL2(dimension)
19
- index.add(np.vstack(encoded_texts))
20
 
 
21
  def retrieve(query, k=5):
22
- query_vector = model(**tokenizer(query, return_tensors='pt', truncation=True, max_length=512)).last_hidden_state.mean(1).detach().numpy()
23
  distances, indices = index.search(query_vector, k)
24
  return [corpus[i] for i in indices[0]]
25
 
 
26
  def generate_response(query):
27
  retrieved_docs = retrieve(query)
28
  context = " ".join(retrieved_docs)
29
 
30
- # Use the retrieved context to generate a humanized response
31
  flan_t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
32
  flan_t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
33
 
 
34
  input_text = f"Generate a human-like response: {query}. Context: {context}"
35
  input_ids = flan_t5_tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).input_ids
36
 
37
- # Generate text with length constraint
38
  generated_ids = flan_t5_model.generate(input_ids, max_length=1500)
39
  response = flan_t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
40
  return response
41
 
 
42
  def trim_to_word_limit(text, word_limit=1500):
43
  words = text.split()
44
  if len(words) > word_limit:
@@ -60,4 +64,4 @@ if st.button("Generate"):
60
  st.write(response)
61
 
62
  # Additional info or about section
63
- st.write("This app uses FAISS, sentence-transformers, and FLAN-T5 to generate contextually relevant human-like responses.")
 
1
import streamlit as st

from functools import lru_cache

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration
6
 
7
+ # Load SentenceTransformer model
8
+ model = SentenceTransformer('all-MiniLM-L6-v2')
 
9
 
10
+ # Prepare dataset (Wikipedia dataset or any other dataset can be used)
11
  corpus = ["Article text 1", "Article text 2", "Article text 3"]
12
 
13
+ # Encode the corpus using the sentence-transformers model
14
+ encoded_texts = model.encode(corpus, convert_to_numpy=True)
15
 
16
  # Create FAISS index
17
+ dimension = encoded_texts.shape[1]
18
  index = faiss.IndexFlatL2(dimension)
19
+ index.add(encoded_texts)
20
 
21
+ # Function to retrieve top-k relevant documents from the corpus
22
  def retrieve(query, k=5):
23
+ query_vector = model.encode([query], convert_to_numpy=True)
24
  distances, indices = index.search(query_vector, k)
25
  return [corpus[i] for i in indices[0]]
26
 
27
+ # Function to generate a human-like response using the FLAN-T5 model
28
  def generate_response(query):
29
  retrieved_docs = retrieve(query)
30
  context = " ".join(retrieved_docs)
31
 
32
+ # Load the FLAN-T5 model and tokenizer
33
  flan_t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
34
  flan_t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base")
35
 
36
+ # Format the input for the model
37
  input_text = f"Generate a human-like response: {query}. Context: {context}"
38
  input_ids = flan_t5_tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).input_ids
39
 
40
+ # Generate text response with a length constraint
41
  generated_ids = flan_t5_model.generate(input_ids, max_length=1500)
42
  response = flan_t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
43
  return response
44
 
45
+ # Function to trim the generated text to a word limit
46
  def trim_to_word_limit(text, word_limit=1500):
47
  words = text.split()
48
  if len(words) > word_limit:
 
64
  st.write(response)
65
 
66
  # Additional info or about section
67
+ st.write("This app uses FAISS, SentenceTransformers, and FLAN-T5 to generate contextually relevant human-like responses.")