anthonym21 committed on
Commit
2ce4d6f
·
1 Parent(s): c3f6521

Switch to HF Inference API - no ZeroGPU needed

Browse files
Files changed (3) hide show
  1. README.md +0 -1
  2. app.py +28 -38
  3. requirements.txt +3 -6
README.md CHANGED
@@ -9,7 +9,6 @@ app_file: app.py
9
  pinned: false
10
  license: mit
11
  short_description: Chat with Anthony Maio's AI safety research papers
12
- hardware: zero-a10g
13
  ---
14
 
15
  # Ask My Research
 
9
  pinned: false
10
  license: mit
11
  short_description: Chat with Anthony Maio's AI safety research papers
 
12
  ---
13
 
14
  # Ask My Research
app.py CHANGED
@@ -1,20 +1,19 @@
1
  """
2
  Ask My Research - RAG chatbot over Anthony Maio's AI safety papers.
3
- Runs on HuggingFace Spaces with ZeroGPU.
4
  """
5
 
6
  import json
 
7
  import time
8
  from pathlib import Path
9
  from collections import defaultdict
10
 
11
  import gradio as gr
12
  import numpy as np
13
- import spaces
14
- import torch
15
- import faiss
16
  from sentence_transformers import SentenceTransformer
17
- from transformers import AutoModelForCausalLM, AutoTokenizer
18
 
19
  # =============================================================================
20
  # Configuration
@@ -102,10 +101,15 @@ else:
102
  faiss_index = None
103
  chunks = []
104
 
105
- print("Loading LLM tokenizer...")
106
- tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
107
- if tokenizer.pad_token is None:
108
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
 
109
 
110
  # =============================================================================
111
  # RAG Functions
@@ -179,20 +183,11 @@ def format_citations(retrieved_chunks: list[dict]) -> str:
179
 
180
 
181
  # =============================================================================
182
- # Generation with ZeroGPU
183
  # =============================================================================
184
 
185
- @spaces.GPU(duration=120)
186
  def generate_response(query: str, context: str) -> str:
187
- """Generate response using the LLM with ZeroGPU."""
188
-
189
- # Load model on GPU
190
- model = AutoModelForCausalLM.from_pretrained(
191
- LLM_MODEL,
192
- torch_dtype=torch.float16,
193
- device_map="auto",
194
- trust_remote_code=True
195
- )
196
 
197
  # Build prompt
198
  system_prompt = """You are a helpful research assistant that answers questions about Anthony Maio's AI safety research papers.
@@ -211,24 +206,19 @@ Question: {query}
211
 
212
  Provide a helpful answer based ONLY on the context above. If the context doesn't contain relevant information, say so."""
213
 
214
- messages = [
215
- {"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}
216
- ]
217
-
218
- # Generate
219
- inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
220
-
221
- with torch.no_grad():
222
- outputs = model.generate(
223
- inputs,
224
- max_new_tokens=MAX_NEW_TOKENS,
225
- do_sample=True,
226
- temperature=0.7,
227
- top_p=0.9,
228
- pad_token_id=tokenizer.pad_token_id,
229
- )
230
 
231
- response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
232
  return response.strip()
233
 
234
 
@@ -263,7 +253,7 @@ def chat(message: str, history: list, request: gr.Request) -> str:
263
  try:
264
  response = generate_response(message, context)
265
  except Exception as e:
266
- return f"Error generating response: {str(e)}"
267
 
268
  # Add citations
269
  citations = format_citations(retrieved)
 
1
  """
2
  Ask My Research - RAG chatbot over Anthony Maio's AI safety papers.
3
+ Runs on HuggingFace Spaces using the Inference API.
4
  """
5
 
6
  import json
7
+ import os
8
  import time
9
  from pathlib import Path
10
  from collections import defaultdict
11
 
12
  import gradio as gr
13
  import numpy as np
14
+ from huggingface_hub import InferenceClient
 
 
15
  from sentence_transformers import SentenceTransformer
16
+ import faiss
17
 
18
  # =============================================================================
19
  # Configuration
 
101
  faiss_index = None
102
  chunks = []
103
 
104
+ # Initialize the Inference Client
105
+ print("Initializing HF Inference Client...")
106
+ hf_token = os.environ.get("HF_TOKEN")
107
+ if hf_token:
108
+ client = InferenceClient(token=hf_token)
109
+ print("Inference client ready with authentication")
110
+ else:
111
+ client = InferenceClient()
112
+ print("WARNING: No HF_TOKEN found - using unauthenticated requests")
113
 
114
  # =============================================================================
115
  # RAG Functions
 
183
 
184
 
185
  # =============================================================================
186
+ # Generation with Inference API
187
  # =============================================================================
188
 
 
189
  def generate_response(query: str, context: str) -> str:
190
+ """Generate response using the HF Inference API."""
 
 
 
 
 
 
 
 
191
 
192
  # Build prompt
193
  system_prompt = """You are a helpful research assistant that answers questions about Anthony Maio's AI safety research papers.
 
206
 
207
  Provide a helpful answer based ONLY on the context above. If the context doesn't contain relevant information, say so."""
208
 
209
+ # Format for Mistral instruction format
210
+ prompt = f"<s>[INST] {system_prompt}\n\n{user_prompt} [/INST]"
211
+
212
+ # Call the Inference API
213
+ response = client.text_generation(
214
+ prompt,
215
+ model=LLM_MODEL,
216
+ max_new_tokens=MAX_NEW_TOKENS,
217
+ temperature=0.7,
218
+ top_p=0.9,
219
+ repetition_penalty=1.1,
220
+ )
 
 
 
 
221
 
 
222
  return response.strip()
223
 
224
 
 
253
  try:
254
  response = generate_response(message, context)
255
  except Exception as e:
256
+ return f"Error generating response: {type(e).__name__}: {str(e)}"
257
 
258
  # Add citations
259
  citations = format_citations(retrieved)
requirements.txt CHANGED
@@ -1,8 +1,5 @@
1
- gradio>=4.44.0
2
- transformers>=4.40.0
3
- torch>=2.0.0
4
  sentence-transformers>=2.2.0
5
  faiss-cpu>=1.7.4
6
- PyMuPDF>=1.23.0
7
- accelerate>=0.27.0
8
- spaces>=0.28.0
 
1
+ gradio>=5.0.0
2
+ huggingface_hub>=0.20.0
 
3
  sentence-transformers>=2.2.0
4
  faiss-cpu>=1.7.4
5
+ numpy