kinely committed on
Commit
b878812
·
verified ·
1 Parent(s): 854fbf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -27
app.py CHANGED
@@ -3,7 +3,6 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import torch
6
- import numpy as np
7
  import wikipediaapi
8
 
9
  # Initialize Wikipedia API with a custom user-agent
@@ -26,26 +25,20 @@ def fetch_wikipedia_articles(titles):
26
  # Initialize SentenceTransformer for embeddings
27
  embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
28
 
29
- # List of Wikipedia articles to retrieve
30
  titles = [
31
- "Crypto",
32
  "Finance",
33
  "Technology",
34
  "Healthcare",
35
  "Education"
36
  ]
37
-
38
- # Fetch and create the corpus
39
- # st.write("Fetching Wikipedia articles...")
40
- st.write("")
41
  corpus = fetch_wikipedia_articles(titles)
42
 
43
  # Generate embeddings for the corpus
44
- # st.write("Generating embeddings...")
45
- st.write("")
46
  embeddings = embedder.encode(corpus, convert_to_tensor=True)
47
-
48
- # Convert embeddings to NumPy array
49
  embeddings_np = embeddings.cpu().numpy()
50
 
51
  # Initialize FAISS index and add embeddings
@@ -60,31 +53,28 @@ tokenizer = T5Tokenizer.from_pretrained(model_name)
60
  # Streamlit interface
61
  st.title("Humanized AI Text Generator")
62
 
63
- # Text input from the user (no character limit, with adjustable height)
64
- user_input = st.text_area("Enter your query here", height=200)
65
 
66
- # Button to generate text
67
  if st.button("Generate Humanized Text"):
68
- if user_input.strip(): # Ensure non-empty input
69
- # Convert user input to embedding for retrieval
70
  query_embedding = embedder.encode([user_input], convert_to_tensor=True)
71
-
72
- # Retrieve top 5 related documents from FAISS index
73
  _, top_k_indices = faiss_index.search(query_embedding.cpu().numpy(), k=5)
74
-
75
- # Retrieve documents based on top_k_indices
76
  def retrieve_documents(top_k_indices):
77
  return " ".join([corpus[i] for i in top_k_indices[0]])
78
 
79
  context = retrieve_documents(top_k_indices)
80
-
81
- # Concatenate query and context
82
  input_text = f"{user_input} {context}"
 
 
 
83
 
84
- # Tokenize input and generate output
85
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024) # Adjusted max_length for input
86
-
87
- # Generate output without truncation in the generate method
88
  outputs = model.generate(inputs.input_ids, max_length=2000, num_return_sequences=1)
89
 
90
  # Decode the generated text
@@ -93,4 +83,4 @@ if st.button("Generate Humanized Text"):
93
  # Display the generated text
94
  st.write(generated_text)
95
  else:
96
- st.write("Please enter a query.")
 
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import torch
 
6
  import wikipediaapi
7
 
8
  # Initialize Wikipedia API with a custom user-agent
 
25
  # Initialize SentenceTransformer for embeddings
26
  embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
27
 
28
+ # Fetch and create the corpus
29
  titles = [
30
+ "Crypto",
31
  "Finance",
32
  "Technology",
33
  "Healthcare",
34
  "Education"
35
  ]
36
+ st.write("Fetching Wikipedia articles...")
 
 
 
37
  corpus = fetch_wikipedia_articles(titles)
38
 
39
  # Generate embeddings for the corpus
40
+ st.write("Generating embeddings...")
 
41
  embeddings = embedder.encode(corpus, convert_to_tensor=True)
 
 
42
  embeddings_np = embeddings.cpu().numpy()
43
 
44
  # Initialize FAISS index and add embeddings
 
53
  # Streamlit interface
54
  st.title("Humanized AI Text Generator")
55
 
56
+ # Input from the user
57
+ user_input = st.text_area("Enter your query here (e.g., about a country, concept, etc.)", height=200)
58
 
 
59
  if st.button("Generate Humanized Text"):
60
+ if user_input.strip():
61
+ # Retrieve context from FAISS based on user input embedding
62
  query_embedding = embedder.encode([user_input], convert_to_tensor=True)
 
 
63
  _, top_k_indices = faiss_index.search(query_embedding.cpu().numpy(), k=5)
64
+
65
+ # Retrieve documents based on FAISS top_k_indices
66
  def retrieve_documents(top_k_indices):
67
  return " ".join([corpus[i] for i in top_k_indices[0]])
68
 
69
  context = retrieve_documents(top_k_indices)
70
+
71
+ # Concatenate user input and context for model input
72
  input_text = f"{user_input} {context}"
73
+
74
+ # Tokenize input
75
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024)
76
 
77
+ # Generate output
 
 
 
78
  outputs = model.generate(inputs.input_ids, max_length=2000, num_return_sequences=1)
79
 
80
  # Decode the generated text
 
83
  # Display the generated text
84
  st.write(generated_text)
85
  else:
86
+ st.write("Please enter a valid query.")