Kakaarot commited on
Commit
dc2e285
·
verified ·
1 Parent(s): 272c653

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -54
app.py CHANGED
@@ -8,7 +8,7 @@ from retry import retry
8
  import os
9
  import json
10
 
11
- # Configure Gemini API with key from Hugging Face Secrets
12
  api_key = os.getenv("GEMINI_API_KEY")
13
  if not api_key:
14
  raise ValueError("GEMINI_API_KEY environment variable not set")
@@ -25,50 +25,44 @@ articles = [
25
  "Coral reefs face bleaching from rising ocean temperatures."
26
  ]
27
 
28
- # Generate embeddings
29
- embedding_model = "models/embedding-001" # Update to correct model name
30
  df = pd.DataFrame({"article": articles})
31
 
32
  @retry(tries=3, delay=2, backoff=2)
33
  def get_embedding(text):
34
  try:
35
  result = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
36
- # Extract embedding correctly based on API response structure
37
  embedding = result.embedding
38
  return embedding
39
  except Exception as e:
40
  print(f"Embedding error: {e}")
41
  raise
42
 
43
- # Generate all embeddings first
44
- all_embeddings = []
45
- for article in articles:
46
- try:
47
- embedding = get_embedding(article)
48
- all_embeddings.append(embedding)
49
- except Exception as e:
50
- print(f"Failed to embed article: {article[:30]}... Error: {e}")
51
- all_embeddings.append([0] * 768) # Default embedding dimension, adjust if needed
52
-
53
- df["embedding"] = all_embeddings
54
 
55
- # Initialize ChromaDB
56
  client_db = chromadb.Client()
57
  collection = client_db.get_or_create_collection("news_articles")
58
 
59
  # Clear existing data to avoid duplicates
60
  try:
61
  collection.delete(ids=[str(i) for i in range(len(df))])
62
- except:
63
  pass # Collection might be empty
64
 
65
- # Add documents to collection
66
  for idx, row in df.iterrows():
67
- collection.add(
68
- documents=[row["article"]],
69
- embeddings=[row["embedding"]],
70
- ids=[str(idx)]
71
- )
 
 
 
72
 
73
  # Semantic Search
74
  @retry(tries=3, delay=2, backoff=2)
@@ -76,6 +70,8 @@ def search_articles(query, top_k=3):
76
  try:
77
  query_embedding = get_embedding(query)
78
  results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
 
 
79
  indices = [int(id) for id in results["ids"][0]]
80
  return df.iloc[indices]["article"].tolist()
81
  except Exception as e:
@@ -83,7 +79,7 @@ def search_articles(query, top_k=3):
83
  return []
84
 
85
  # RAG and Structured Q&A
86
- generation_model = genai.GenerativeModel("gemini-1.5-pro") # Verify model name
87
 
88
  @retry(tries=3, delay=2, backoff=2)
89
  def generate_response(query, articles, system_message):
@@ -91,13 +87,6 @@ def generate_response(query, articles, system_message):
91
  return "No relevant articles found.", json.dumps({"error": "No relevant articles found."})
92
 
93
  context = "\n".join(articles)
94
- safety_settings = [
95
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
96
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
97
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
98
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
99
- ]
100
-
101
  prompt = f"""
102
  {system_message}
103
  Based on the following articles, provide a concise summary (under 100 words) and a structured JSON response with 'question', 'answer', and 'source'. Use only the provided context.
@@ -111,41 +100,45 @@ def generate_response(query, articles, system_message):
111
  - Summary:
112
  - JSON:
113
  """
 
114
  try:
115
- generation_config = {
116
- "temperature": 0.7,
117
- "top_p": 0.95,
118
- "top_k": 40,
119
- "max_output_tokens": 1024,
120
- }
121
-
122
  response = generation_model.generate_content(
123
  prompt,
124
- generation_config=generation_config,
125
- safety_settings=safety_settings,
 
 
 
126
  stream=False
127
  )
128
 
129
  full_text = response.text
130
 
131
- # Parse response
132
- summary_end = full_text.find("- JSON:")
133
- summary = full_text[full_text.find("- Summary:") + len("- Summary:"):summary_end].strip() if "- Summary:" in full_text else "Summary not generated."
134
- qa_json = full_text[summary_end + len("- JSON:"):].strip()
135
-
136
- # Clean up the JSON string to make it parseable
137
- qa_json = qa_json.replace("``````", "").strip()
138
 
139
- try:
140
- qa = json.loads(qa_json)
141
- except json.JSONDecodeError:
142
- print(f"JSON parse error. Raw string: {qa_json}")
143
- qa = {"error": "Failed to parse JSON response.", "raw_text": qa_json}
 
 
 
 
 
 
 
144
 
145
- return summary, json.dumps(qa, indent=2)
146
  except Exception as e:
147
  print(f"RAG error: {e}")
148
- return "Error generating response.", json.dumps({"error": f"Failed to generate response: {str(e)}"})
149
 
150
  def respond(message, history, system_message="You are a news summarizer and Q&A assistant.", max_tokens=512, temperature=0.7, top_p=0.95):
151
  articles = search_articles(message)
 
8
  import os
9
  import json
10
 
11
+ # API Key Validation
12
  api_key = os.getenv("GEMINI_API_KEY")
13
  if not api_key:
14
  raise ValueError("GEMINI_API_KEY environment variable not set")
 
25
  "Coral reefs face bleaching from rising ocean temperatures."
26
  ]
27
 
28
+ # Generate embeddings - with corrected model name
29
+ embedding_model = "models/embedding-001" # Correct model name
30
  df = pd.DataFrame({"article": articles})
31
 
32
  @retry(tries=3, delay=2, backoff=2)
33
  def get_embedding(text):
34
  try:
35
  result = genai.embed_content(model=embedding_model, content=text, task_type="RETRIEVAL_DOCUMENT")
36
+ # Correct way to access embedding
37
  embedding = result.embedding
38
  return embedding
39
  except Exception as e:
40
  print(f"Embedding error: {e}")
41
  raise
42
 
43
+ # Generate embeddings and ensure they're in the correct format
44
+ df["embedding"] = df["article"].apply(get_embedding)
 
 
 
 
 
 
 
 
 
45
 
46
+ # Initialize ChromaDB with proper error handling
47
  client_db = chromadb.Client()
48
  collection = client_db.get_or_create_collection("news_articles")
49
 
50
  # Clear existing data to avoid duplicates
51
  try:
52
  collection.delete(ids=[str(i) for i in range(len(df))])
53
+ except Exception:
54
  pass # Collection might be empty
55
 
56
+ # Add documents with error handling
57
  for idx, row in df.iterrows():
58
+ try:
59
+ collection.add(
60
+ documents=[row["article"]],
61
+ embeddings=[row["embedding"]],
62
+ ids=[str(idx)]
63
+ )
64
+ except Exception as e:
65
+ print(f"Error adding document {idx}: {e}")
66
 
67
  # Semantic Search
68
  @retry(tries=3, delay=2, backoff=2)
 
70
  try:
71
  query_embedding = get_embedding(query)
72
  results = collection.query(query_embeddings=[query_embedding], n_results=top_k)
73
+ if not results["ids"][0]:
74
+ return []
75
  indices = [int(id) for id in results["ids"][0]]
76
  return df.iloc[indices]["article"].tolist()
77
  except Exception as e:
 
79
  return []
80
 
81
  # RAG and Structured Q&A
82
+ generation_model = genai.GenerativeModel("gemini-1.5-pro") # Corrected model name
83
 
84
  @retry(tries=3, delay=2, backoff=2)
85
  def generate_response(query, articles, system_message):
 
87
  return "No relevant articles found.", json.dumps({"error": "No relevant articles found."})
88
 
89
  context = "\n".join(articles)
 
 
 
 
 
 
 
90
  prompt = f"""
91
  {system_message}
92
  Based on the following articles, provide a concise summary (under 100 words) and a structured JSON response with 'question', 'answer', and 'source'. Use only the provided context.
 
100
  - Summary:
101
  - JSON:
102
  """
103
+
104
  try:
 
 
 
 
 
 
 
105
  response = generation_model.generate_content(
106
  prompt,
107
+ generation_config={
108
+ "temperature": 0.7,
109
+ "top_p": 0.95,
110
+ "max_output_tokens": 1024,
111
+ },
112
  stream=False
113
  )
114
 
115
  full_text = response.text
116
 
117
+ # Robust parsing
118
+ summary = "Summary not generated."
119
+ if "- Summary:" in full_text:
120
+ summary_start = full_text.find("- Summary:") + len("- Summary:")
121
+ summary_end = full_text.find("- JSON:", summary_start)
122
+ if summary_end > summary_start:
123
+ summary = full_text[summary_start:summary_end].strip()
124
 
125
+ qa_json = "{}"
126
+ if "- JSON:" in full_text:
127
+ json_start = full_text.find("- JSON:") + len("- JSON:")
128
+ qa_json_text = full_text[json_start:].strip()
129
+ # Clean up the JSON string - remove markdown code blocks
130
+ qa_json_text = qa_json_text.replace("``````", "").strip()
131
+
132
+ try:
133
+ qa = json.loads(qa_json_text)
134
+ qa_json = json.dumps(qa, indent=2)
135
+ except json.JSONDecodeError:
136
+ qa_json = json.dumps({"error": "Failed to parse JSON response.", "raw_text": qa_json_text})
137
 
138
+ return summary, qa_json
139
  except Exception as e:
140
  print(f"RAG error: {e}")
141
+ return f"Error generating response: {str(e)}", json.dumps({"error": f"Failed to generate response: {str(e)}"})
142
 
143
  def respond(message, history, system_message="You are a news summarizer and Q&A assistant.", max_tokens=512, temperature=0.7, top_p=0.95):
144
  articles = search_articles(message)