allyyy commited on
Commit
adfc625
·
verified ·
1 Parent(s): 703fb27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -62
app.py CHANGED
@@ -1,3 +1,9 @@
 
 
 
 
 
 
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util
3
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -40,75 +46,18 @@ def load_and_preprocess_text(filename):
40
 
41
  segments = load_and_preprocess_text(filename)
42
 
43
def find_relevant_segment(user_query, segments, similarity_threshold=0.5):
    """
    Locate the text segment most similar to a user's query.

    Parameters:
    - user_query (str): The user's query.
    - segments (list[str]): Candidate text segments to search within.
    - similarity_threshold (float): Minimum cosine similarity required to
      consider a segment relevant.

    Returns:
    - str: The best-matching segment, a generic-tips fallback message when
      nothing clears the threshold, or "" if an error occurs.
    """
    try:
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = retrieval_model.encode(segments)
        scores = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        top = scores.argmax()
        # Guard clause: below the similarity bar, fall back to generic tips.
        if scores[top].item() < similarity_threshold:
            return "Sorry, I couldn't find a specific match. Here are some general tips to help you:"
        return segments[top]
    except Exception as e:
        # Best-effort behavior: log and return an empty string rather than
        # letting the UI crash on an encoding/search failure.
        print(f"Error finding relevant segment: {e}")
        return ""
65
 
66
def clean_up_response(response, segment):
    """
    Tidy a generated response so it is presentable.

    Splits the response on '.', drops empty pieces and any sentence that
    already appears (case-insensitively) inside the source segment, re-joins
    the rest, and guarantees terminal punctuation.

    Parameters:
    - response (str): The initial response generated by the model.
    - segment (str): The segment used to generate the response.

    Returns:
    - str: A cleaned and formatted response ("" when nothing survives).
    """
    kept = []
    for raw in response.split('.'):
        sentence = raw.strip()
        # Skip blanks and sentences the segment already contains verbatim.
        if sentence and sentence.lower() not in segment.lower():
            kept.append(sentence)
    cleaned_response = '. '.join(kept).strip()
    if cleaned_response and not cleaned_response.endswith((".", "!", "?")):
        cleaned_response += "."
    return cleaned_response
81
-
82
-
83
def generate_response_with_context(user_query, relevant_segment):
    """
    Generate a reply that weaves a relevant segment into the answer.

    Parameters:
    - user_query (str): The user's query.
    - relevant_segment (str): A relevant fact or detail.

    Returns:
    - str: Cleaned response text incorporating the segment, or an apology
      message if generation fails.
    """
    try:
        # Hand the model the supporting context inside the prompt itself.
        prompt = f"User: {user_query}\n\nAssistant: Here is some helpful information based on your topic: {relevant_segment}"

        # Budget ~100 new tokens beyond the prompt length.
        token_budget = len(tokenizer(prompt)['input_ids']) + 100

        generated = gpt_model(prompt, max_length=token_budget, temperature=0.7)[0]['generated_text']

        # Strip sentences that merely echo the segment before returning.
        return clean_up_response(generated, relevant_segment)

    except Exception as e:
        print(f"Error generating response: {e}")
        return "I'm sorry, but there was an error generating your response. Please try again."
110
-
111
-
112
  def generate_response(user_query, relevant_segment):
113
  try:
114
  user_message = f"Here's the information on your request: {relevant_segment}"
 
1
+
2
+
3
+ Share
4
+
5
+
6
+ You said:
7
  import gradio as gr
8
  from sentence_transformers import SentenceTransformer, util
9
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
46
 
47
  segments = load_and_preprocess_text(filename)
48
 
49
def find_relevant_segment(user_query, segments):
    """
    Return the segment whose embedding is most similar to the query.

    Parameters:
    - user_query (str): The user's query; lower-cased before encoding for
      more consistent embeddings.
    - segments (list[str]): Candidate text segments.

    Returns:
    - str: The best-matching segment, or "" if anything goes wrong.
    """
    try:
        lower_query = user_query.lower()
        query_embedding = retrieval_model.encode(lower_query)
        segment_embeddings = retrieval_model.encode(segments)
        scores = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # argmax yields a 0-dim tensor; list indexing accepts it via __index__.
        return segments[scores.argmax()]
    except Exception as e:
        # Best-effort: log and return "" instead of propagating the error.
        print(f"Error in finding relevant segment: {e}")
        return ""
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def generate_response(user_query, relevant_segment):
62
  try:
63
  user_message = f"Here's the information on your request: {relevant_segment}"