jme-datasci committed
Commit 3ceae4d · 1 Parent(s): 0166d40

optimize for qwen

Files changed (1)
  1. app.py +85 -98
app.py CHANGED
@@ -1,156 +1,143 @@
  import os
  import torch
  import gradio as gr
- import faiss
- import numpy as np
- from tqdm import tqdm
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import spaces
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from langchain_huggingface import HuggingFaceEmbeddings
  from langchain_community.vectorstores import FAISS
- import spaces
+ from threading import Thread

- # Ensure an HF Token is present for gated models (like Llama 3)
  HF_TOKEN = os.getenv("HF_TOKEN")
- rag_pipeline = None
+
  class MyRAGPipeline:
-     '''
-     Wrapper class for RAG pipeline.
-     '''
-     def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str, tokenizer_name=None, MAX_NEW_TOKENS=500, TEMPERATURE=0.7, DO_SAMPLE=True):
-         if tokenizer_name is None:
-             tokenizer_name = model_name
-
+     def __init__(self, model_name: str, embedding_model_name: str, vector_db_path: str):
          self.embedding_model_name = embedding_model_name
-         self.max_new_tokens = MAX_NEW_TOKENS
+         self.max_new_tokens = 500

          print(f"Loading Model: {model_name}...")
-         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, token=HF_TOKEN)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
+
+         # --- CRITICAL: Load to CPU first ---
+         # ZeroGPU does not have a GPU available during global startup.
+         # We load the weights into System RAM now, and move them to GPU later.
          self.model = AutoModelForCausalLM.from_pretrained(
              model_name,
-             device_map="auto",
-             dtype=torch.bfloat16,
+             device_map="cpu", # Force CPU loading
+             torch_dtype=torch.bfloat16,
              token=HF_TOKEN
          )
+
          self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
          self.tokenizer.padding_side = "left"

          print("Loading Embeddings...")
          self.embedding_model = HuggingFaceEmbeddings(
              model_name=self.embedding_model_name,
-             multi_process=False, # Set to False for stability in Spaces
-             model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
+             model_kwargs={"device": "cpu"}, # Keep embeddings on CPU
              encode_kwargs={"normalize_embeddings": True},
          )

          print(f"Loading Vector DB from {vector_db_path}...")
-         # Check if index exists to prevent crash
          if not os.path.exists(vector_db_path):
              raise FileNotFoundError(f"Could not find vector DB at {vector_db_path}. Please upload your 'index' folder.")

          self.vector_db = FAISS.load_local(vector_db_path, self.embedding_model, allow_dangerous_deserialization=True)
-
-         # FAISS GPU optimization (If available)
-         if torch.cuda.is_available():
-             try:
-                 res = faiss.StandardGpuResources()
-                 co = faiss.GpuClonerOptions()
-                 co.useFloat16 = True
-                 self.vector_db.index = faiss.index_cpu_to_gpu(res, 0, self.vector_db.index, co)
-             except Exception as e:
-                 print(f"Could not load FAISS to GPU, running on CPU: {e}")
-
-         # Initialize Pipeline
-         self.pipe = pipeline(
-             'text-generation',
-             model=self.model,
-             torch_dtype=torch.bfloat16,
-             device_map='auto',
-             tokenizer=self.tokenizer,
-             max_new_tokens=self.max_new_tokens,
-             temperature=TEMPERATURE,
-             do_sample=DO_SAMPLE,
-             pad_token_id=self.tokenizer.eos_token_id,
-             # return_full_text=False is CRITICAL for chatbots so it doesn't repeat the prompt
-             return_full_text=False
-         )
+         print("RAG Pipeline Initialized (CPU Mode)")

      def retrieve(self, query, num_docs=3):
-         '''
-         Returns the k most similar documents to the query
-         '''
-         retrieved_docs = self.vector_db.similarity_search(query, k=num_docs)
-         return retrieved_docs
+         return self.vector_db.similarity_search(query, k=num_docs)

      def _format_prompt(self, query, retrieved_docs):
-         context = "\nExtracted documents:\n"
-         # Adjusted extraction slightly to handle missing metadata keys gracefully
+         # 1. Build Context
+         context = "Extracted documents:\n"
          for doc in retrieved_docs:
              section = doc.metadata.get('Section', 'N/A')
              subtitle = doc.metadata.get('Subtitle', 'Context')
              context += f"{section} - {subtitle}:::\n{doc.page_content}\n\n"

-         prompt = f'''
-         You are a helpful legal interpreter.
-         You are given the following context:
-         {context}\n\n
-         Using the information contained in the context,
-         give a comprehensive answer to the question.
-         Respond only to the question asked. Your response should be concise and relevant to the question.
-         Always provide the section number and title of the source document.
-         Also please use plain English when responding, not legal jargon.
+         # 2. Universal Chat Template (Works for Qwen, Llama, etc.)
+         messages = [
+             {
+                 "role": "system",
+                 "content": f"You are a helpful legal interpreter. Use the following context to answer the user's question.\nContext:\n{context}"
+             },
+             {
+                 "role": "user",
+                 "content": query
+             }
+         ]

-         Question: {query}"
-         '''
+         prompt = self.tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True
+         )
          return prompt

-     def easy_generate(self, query, num_docs=3):
-         retrieved_docs = self.retrieve(query, num_docs=num_docs)
-         prompt = self._format_prompt(query, retrieved_docs)
+     def generate(self, query, num_docs=3):
+         # 1. Retrieve
+         retrieved_docs = self.retrieve(query, num_docs)
+
+         # 2. Format Prompt
+         prompt_str = self._format_prompt(query, retrieved_docs)
+
+         # 3. Tokenize
+         inputs = self.tokenizer(prompt_str, return_tensors="pt").to(self.model.device)
+
+         # 4. Generate (Streaming is simpler for direct model usage, but here we do blocking)
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=self.max_new_tokens,
+                 temperature=0.7,
+                 do_sample=True,
+                 pad_token_id=self.tokenizer.eos_token_id
+             )
+
+         # 5. Decode
+         # Slicing [input_len:] ensures we only return the new text, not the prompt
+         input_len = inputs.input_ids.shape[1]
+         generated_text = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

-         # Because we used return_full_text=False in the pipeline,
-         # this returns only the answer.
-         result = self.pipe(prompt)[0]['generated_text']
-         return result
+         return generated_text

- # --- INITIALIZATION ---
- # Using standard paths and models
- MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
+ # --- CONFIGURATION ---
+ MODEL_NAME = 'Qwen/Qwen2.5-7B-Instruct'
  EMBEDDING_NAME = 'Qwen/Qwen3-Embedding-0.6B'
- VECDB_PATH = './index/'
+ VECDB_PATH = 'index/'

- # Initialize the RAG system globally so it doesn't reload on every message
+ # --- GLOBAL INSTANTIATION ---
+ # This runs once when the container starts.
  try:
      rag = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
  except Exception as e:
+     print(f"Initialization Error: {e}")
      rag = None
-     print(f"Error initializing RAG: {e}")

- # --- GRADIO INTERFACE ---
- @spaces.GPU(duration=120)
+ # --- ZERO-GPU INFERENCE FUNCTION ---
+ @spaces.GPU
  def chat_function(message, history):
-     global rag_pipeline
+     if rag is None:
+         return "System Error: RAG Pipeline failed to initialize."

-     # Initialize ONLY when the GPU is assigned
-     if rag_pipeline is None:
-         print("Initializing RAG Pipeline on GPU...")
-         rag_pipeline = MyRAGPipeline(MODEL_NAME, EMBEDDING_NAME, VECDB_PATH)
-
-     return rag_pipeline.easy_generate(message)
+     # 1. Move Model to GPU (Fast operation on ZeroGPU)
+     print("Moving model to GPU...")
+     rag.model.to("cuda")
+
+     # 2. Generate
+     response = rag.generate(message)
+
+     # 3. (Optional) Move back to CPU to save VRAM?
+     # Usually not needed as ZeroGPU handles cleanup, but good practice if sharing resources.
+     # rag.model.to("cpu")
+
+     return response

  demo = gr.ChatInterface(
-     fn=chat_function,
+     fn=chat_function,
      type="messages",
-     title="Charlottesville Municipal RAG Assistant",
-     description="Ask a question about the City of Charlottesville municipal code. This app is intended to increase accessibility to the municipal code and is not a replacement for a legal professional. AI makes mistakes, check important information.",
-     examples=[
-         "My neighbor is playing loud music on their porch. What time does the 'quiet period' start, and what is the maximum decibel level allowed in a residential zone?",
-         "There is a massive oak tree on my property I want to cut down. Do I need permission from the city to remove it?",
-         "I got a parking ticket near the Downtown Mall. What is the deadline to pay the fine, and how do I contest it if I think it was issued in error?",
-         "I want to build a privacy fence in my backyard. How tall can it be before I need a permit, and are there different rules for the front yard versus the back yard?",
-         "I found a deer in my backyard. Can I keep it as a pet if I put a leash on it?",
-         "I'm having trouble catching fish in the Rivanna River. Is it legal to use explosives to help catch them?",
-         "Can I legally attach a flamethrower to my car to melt the snow on my driveway?",
-         "Is it legal for me to practice my bagpipes on the sidewalk at 2:00 AM if I'm technically walking and not 'loitering'?"]
+     title="Legal RAG Assistant (Qwen 2.5)",
+     description="Ask a question about the legal documents.",
  )

  if __name__ == "__main__":
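
Note on the ZeroGPU pattern this commit adopts: the model is loaded onto the CPU while the Space container starts (ZeroGPU attaches no GPU at import time) and is moved to CUDA only inside the @spaces.GPU-decorated handler. A minimal, self-contained sketch of that pattern, separate from this app, is shown below; the small checkpoint name is a stand-in for illustration, not the model used here.

import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

CHECKPOINT = "Qwen/Qwen2.5-0.5B-Instruct"  # stand-in checkpoint, not the one in app.py

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT,
    device_map="cpu",            # no GPU exists at import time on ZeroGPU
    torch_dtype=torch.bfloat16,
)

@spaces.GPU  # a GPU is attached only while this function runs
def answer(prompt: str) -> str:
    model.to("cuda")             # move the CPU-resident weights to the allocated GPU
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=64)
    # return only the newly generated tokens, not the echoed prompt
    return tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)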
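
The hand-written f-string prompt is likewise replaced by the tokenizer's chat template, so the prompt format follows whatever the loaded model expects. For a Qwen2.5 (ChatML-style) tokenizer, apply_chat_template with tokenize=False and add_generation_prompt=True renders the messages roughly as sketched below; the commented output is an approximation for illustration, not output captured from the app.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
messages = [
    {"role": "system", "content": "You are a helpful legal interpreter."},
    {"role": "user", "content": "How tall can my backyard fence be?"},
]
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
# Approximate output (ChatML):
# <|im_start|>system
# You are a helpful legal interpreter.<|im_end|>
# <|im_start|>user
# How tall can my backyard fence be?<|im_end|>
# <|im_start|>assistant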