goldrode committed
Commit 5d7671d · verified · 1 Parent(s): 1960cfb

Update app.py

Files changed (1):
  app.py +14 -26
app.py CHANGED
@@ -4,6 +4,7 @@ import os
 import faiss
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 # Load the knowledge base
 with open("knowledge_base.json", "r") as file:
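The new torch import backs the device handling added further down in this commit, which pins the input tensor to CPU. For reference, a minimal sketch of the usual device-selection idiom; choosing the device dynamically like this is an alternative, not what the commit does:

import torch

# Prefer a GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")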
@@ -29,41 +30,23 @@ llm = AutoModelForCausalLM.from_pretrained(llama_model_name, token=API_TOKEN)
 # Generate advice using RAG
 def generate_advice(extracted_data):
     try:
-        # Ensure extracted_data is valid
-        if not isinstance(extracted_data, list):
-            raise ValueError("Input data must be a list of dictionaries.")
-        if not all(isinstance(item, dict) for item in extracted_data):
-            raise ValueError("Each item in input data must be a dictionary.")
-
         recommendations = []
 
         for item in extracted_data:
-            # Validate required keys
+            # Validate input keys
             if not all(k in item for k in ["Component", "Status"]):
                 raise ValueError("Each input item must have 'Component' and 'Status' keys.")
-
+
             # Prepare the query string
             query = f"{item['Component']} {item['Status']}"
-            print(f"Processing query: {query}")  # Debug print
+            print(f"Processing query: {query}")
 
-            # Generate query embedding and reshape
+            # Generate query embedding
             query_embedding = embedding_model.encode([query])
             query_embedding = np.array(query_embedding, dtype="float32").reshape(1, -1)
 
-            # Debugging embedding dimensions
-            print(f"Query Embedding Shape: {query_embedding.shape}, FAISS Index Dim: {index.d}")
-
-            # Validate embedding dimensions
-            if query_embedding.shape[1] != index.d:
-                raise ValueError(
-                    f"Embedding dimension mismatch: Query ({query_embedding.shape[1]}), Index ({index.d})"
-                )
-
             # Search for the closest match in FAISS
             _, idx = index.search(query_embedding, 1)
-            print(f"FAISS Index: {idx}, Best Match Raw: {kb[idx[0][0]]}")
-
-            # Retrieve the closest match
             best_match = kb[idx[0][0]]
 
             # Prepare the LLM prompt
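This hunk drops the explicit dimension check against index.d; FAISS enforces the same invariant itself and errors at search time on a width mismatch, so the guard was redundant. For readers following along, a minimal, self-contained sketch of the same encode-and-search flow; the model name and toy corpus below are placeholders, not taken from app.py:

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Placeholder corpus standing in for knowledge_base.json.
kb = ["Brake pads worn", "Oil level low", "Battery voltage nominal"]

# Placeholder model; app.py's embedding_model may differ.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
kb_embeddings = np.asarray(embedding_model.encode(kb), dtype="float32")

# Exact L2 index whose dimensionality (index.d) matches the embedding width.
index = faiss.IndexFlatL2(kb_embeddings.shape[1])
index.add(kb_embeddings)

# Encode a query and retrieve its single nearest neighbour, as in the diff.
query_embedding = np.asarray(
    embedding_model.encode(["Brakes Worn"]), dtype="float32"
).reshape(1, -1)
_, idx = index.search(query_embedding, 1)
print(kb[idx[0][0]])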
@@ -78,21 +61,26 @@ def generate_advice(extracted_data):
             Provide additional insights or recommendations.
             """
 
-            # Generate advice using LLaMA model
+            # Tokenize input properly for LLaMA
             message_yours = [
                 {"role": "system", "content": role},
                 {"role": "user", "content": prompt},
             ]
 
+            # Properly tokenize to return a PyTorch tensor
             input_text_with_your_role = tokenizer.apply_chat_template(
                 message_yours,
-                tokenize=False,
+                tokenize=True,  # Must tokenize to return input_ids
                 add_generation_prompt=True,
                 return_tensors="pt",
             )
 
+            # Move tensor to appropriate device (CPU/GPU)
+            input_text_with_your_role = input_text_with_your_role.to(torch.device("cpu"))
+
+            # Generate advice
             output = llm.generate(
-                input_ids=input_text_with_your_role,
+                input_ids=input_text_with_your_role["input_ids"],
                 max_length=150,
                 num_return_sequences=1
             )
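One caveat on the new tokenize=True path: in the transformers releases we know, apply_chat_template with tokenize=True and return_tensors="pt" returns a bare tensor of token ids unless return_dict=True is also passed, so the ["input_ids"] indexing above would fail at runtime. A sketch of the variant we would expect to run, assuming tokenizer, llm, role, and prompt as defined in app.py; return_dict=True and the attention mask are our additions, not part of this commit:

messages = [
    {"role": "system", "content": role},
    {"role": "user", "content": prompt},
]

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,  # our addition: yields a dict, so ["input_ids"] works
)
inputs = inputs.to(llm.device)  # follow the model's device instead of hard-coding CPU

output = llm.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # our addition: avoids padding ambiguity
    max_length=150,
    num_return_sequences=1,
)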
@@ -107,7 +95,7 @@ def generate_advice(extracted_data):
         return recommendations
 
     except Exception as e:
-        print(f"Error occurred: {str(e)}")  # Debugging error
+        print(f"Error occurred: {str(e)}")
         return [{"error": f"Exception occurred: {str(e)}"}]
 
 # Gradio app with LLM integration
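The step that decodes output back into text falls in the unchanged lines the diff omits. Purely for reference, the usual pattern with this API looks like the following; it is illustrative, not necessarily the code in app.py, and the result shape is hypothetical:

# Illustrative only: decode the first generated sequence back into text.
advice = tokenizer.decode(output[0], skip_special_tokens=True)
recommendations.append({"query": query, "advice": advice})  # hypothetical result shape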
 