NeoPy committed (verified)
Commit f7da735 · 1 Parent(s): cb15ef3

Update app.py

Files changed (1):
  app.py  +150 -97
app.py CHANGED
@@ -1,44 +1,110 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    StoppingCriteria,
+    StoppingCriteriaList,
+    TextIteratorStreamer,
+    BlipProcessor,
+    BlipForConditionalGeneration
+)
 from threading import Thread
 from sentence_transformers import SentenceTransformer, util
+import requests
+from bs4 import BeautifulSoup
+from PIL import Image
 
 # --- CONFIGURATION ---
-# Loading the tokenizer and model from Hugging Face's model hub.
+
+# 1. LLM: TinyLlama
 print("Loading TinyLlama...")
 tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
 
-# Loading the Embedding model for RAG
+# 2. Embedding Model: For Text RAG
 print("Loading Embedding Model...")
 embedder = SentenceTransformer('all-MiniLM-L6-v2')
 
-# using CUDA for an optimal experience
+# 3. Vision Model: BLIP (for Image to Text)
+# We use this to convert images into text descriptions so TinyLlama can "read" them.
+print("Loading Vision Model (BLIP)...")
+vision_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+vision_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+# Device Setup
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 model = model.to(device)
+vision_model = vision_model.to(device)
 
 # --- GLOBAL STATE FOR RAG ---
 KNOWLEDGE_CHUNKS = []
 KNOWLEDGE_EMBEDDINGS = None
 RAG_ENABLED = False
 
-# System content - Define the assistant's personality and capabilities
+# System content
 DEFAULT_SYSTEM_PROMPT = """You are TinyLlama, a friendly and helpful AI assistant.
-You are based on the TinyLlama-1.1B-Chat model and you excel at providing clear,
-concise answers to various questions."""
-
+You are based on the TinyLlama-1.1B-Chat model."""
 SYSTEM_CONTENT = DEFAULT_SYSTEM_PROMPT
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [2] # IDs of tokens where the generation should stop.
+        stop_ids = [2]
         for stop_id in stop_ids:
-            if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
+            if input_ids[0][-1] == stop_id:
                 return True
         return False
 
+# --- NEW TOOL FUNCTIONS ---
+
+def scrape_wikifandom(url):
+    """Scrapes text content from a WikiFandom page."""
+    if "fandom.com" not in url:
+        return "Error: Please provide a valid URL containing 'fandom.com'"
+
+    try:
+        headers = {'User-Agent': 'Mozilla/5.0'}
+        response = requests.get(url, headers=headers)
+        if response.status_code != 200:
+            return f"Error: Failed to fetch page (Status {response.status_code})"
+
+        soup = BeautifulSoup(response.content, 'html.parser')
+
+        # Fandom usually puts the main article text in 'mw-parser-output'
+        content_div = soup.find('div', class_='mw-parser-output')
+
+        if not content_div:
+            # Fallback for some wiki layouts
+            content_div = soup.find('div', id='content')
+
+        if not content_div:
+            return "Error: Could not parse content from this Fandom page."
+
+        # Extract paragraphs
+        paragraphs = content_div.find_all('p')
+        text_content = "\n\n".join([p.get_text() for p in paragraphs if len(p.get_text()) > 50])
+
+        return text_content
+    except Exception as e:
+        return f"Error scraping URL: {str(e)}"
+
+def process_image_to_text(image):
+    """Generates a caption for an image using BLIP."""
+    if image is None:
+        return ""
+
+    try:
+        # Prepare image
+        inputs = vision_processor(image, return_tensors="pt").to(device)
+
+        # Generate caption
+        out = vision_model.generate(**inputs, max_new_tokens=50)
+        caption = vision_processor.decode(out[0], skip_special_tokens=True)
+
+        return f"Image Context: The user uploaded an image that shows {caption}."
+    except Exception as e:
+        return f"Error processing image: {str(e)}"
+
 # --- RAG FUNCTIONS ---
 
 def process_knowledge_base(text_content):
@@ -49,15 +115,14 @@ def process_knowledge_base(text_content):
         RAG_ENABLED = False
         return "Knowledge base cleared.", False
 
-    # 1. Simple Chunking (Split by paragraphs or roughly by max characters)
-    # For a real app, use a proper text splitter (like RecursiveCharacterTextSplitter)
+    # Chunking
    raw_chunks = text_content.split('\n\n')
     chunks = [chunk.strip() for chunk in raw_chunks if len(chunk.strip()) > 20]
 
     if not chunks:
         return "No valid text found to process.", False
 
-    # 2. Create Embeddings
+    # Create Embeddings
     try:
         embeddings = embedder.encode(chunks, convert_to_tensor=True)
 
@@ -65,27 +130,21 @@ def process_knowledge_base(text_content):
         KNOWLEDGE_EMBEDDINGS = embeddings
         RAG_ENABLED = True
 
-        return f"Successfully indexed {len(chunks)} text chunks.", True
+        return f"Indexed {len(chunks)} chunks. RAG Ready.", True
     except Exception as e:
         return f"Error creating embeddings: {str(e)}", False
 
 def retrieve_context(query, top_k=3):
-    """Finds relevant chunks for the query."""
     if not RAG_ENABLED or KNOWLEDGE_EMBEDDINGS is None:
         return ""
 
-    # Encode user query
     query_embedding = embedder.encode(query, convert_to_tensor=True)
-
-    # Compute Cosine Similarity
     cos_scores = util.cos_sim(query_embedding, KNOWLEDGE_EMBEDDINGS)[0]
-
-    # Get top_k results
     top_results = torch.topk(cos_scores, k=min(top_k, len(KNOWLEDGE_CHUNKS)))
 
     retrieved_text = []
     for score, idx in zip(top_results[0], top_results[1]):
-        if score > 0.3: # Threshold to ensure relevance
+        if score > 0.25:  # Slightly lower threshold for broader context
             retrieved_text.append(KNOWLEDGE_CHUNKS[idx])
 
     return "\n\n".join(retrieved_text)
@@ -93,30 +152,23 @@ def retrieve_context(query, top_k=3):
 # --- PREDICTION FUNCTION ---
 
 def predict(message, history, system_content=None):
-    # Use custom system content if provided, otherwise use default
     current_system_content = system_content if system_content else SYSTEM_CONTENT
 
-    # --- RAG LOGIC ---
     context = ""
     if RAG_ENABLED:
         retrieved = retrieve_context(message)
         if retrieved:
-            context = f"\nUse the following context to answer the user's question:\n{retrieved}\n"
-            # We modify the prompt to include the context
+            context = f"\nUse this context to answer:\n{retrieved}\n"
             message = f"{context}\nQuestion: {message}"
-    # -----------------
 
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
 
-    # Formatting the input for the model with system content
     system_prompt = f"<|system|>\n{current_system_content}</s>"
     conversation = "</s>".join(["</s>".join(["\n<|user|>:" + item[0], "\n<|assistant|>:" + item[1]])
                                 for item in history_transformer_format])
 
     messages = system_prompt + conversation
-
-    # Tokenize
     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
@@ -132,103 +184,104 @@ def predict(message, history, system_content=None):
         stopping_criteria=StoppingCriteriaList([stop])
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start() # Starting the generation in a separate thread.
+    t.start()
 
     partial_message = ""
     for new_token in streamer:
         partial_message += new_token
-        if '</s>' in partial_message: # Breaking the loop if the stop token is generated.
+        if '</s>' in partial_message:
             break
         yield partial_message
 
-# --- UI HANDLERS ---
+# --- UI LOGIC ---
 
-def update_system_content(system_content):
-    global SYSTEM_CONTENT
-    if system_content.strip():
-        SYSTEM_CONTENT = system_content
-        return "System content updated successfully!"
-    else:
-        return "Please enter valid system content."
+def add_fandom_content(url, current_text):
+    """Fetches fandom content and appends it to the textbox."""
+    scraped_text = scrape_wikifandom(url)
+    if scraped_text.startswith("Error"):
+        return current_text, scraped_text  # Return error in status
+
+    new_text = (current_text + "\n\n" + scraped_text).strip()
+    return new_text, "Fandom content added to Knowledge Base text area."
 
-def reset_system_content():
-    global SYSTEM_CONTENT
-    SYSTEM_CONTENT = DEFAULT_SYSTEM_PROMPT
-    return DEFAULT_SYSTEM_PROMPT, "System content reset to default!"
+def add_image_content(image, current_text):
+    """Analyzes image and appends description to textbox."""
+    description = process_image_to_text(image)
+    if description.startswith("Error"):
+        return current_text, description
+
+    new_text = (current_text + "\n\n" + description).strip()
+    return new_text, "Image analysis added. RAG now knows what this image looks like."
 
 # --- GRADIO INTERFACE ---
 
-with gr.Blocks(title="TinyLlama ChatBot + RAG") as demo:
-    gr.Markdown("# 🦙 TinyLlama RAG ChatBot")
-    gr.Markdown("Chat with TinyLlama-1.1B. Use the **RAG settings** to add your own context (Knowledge Base).")
+with gr.Blocks(title="TinyLlama Multi-Source RAG") as demo:
+    gr.Markdown("# 🦙 TinyLlama RAG (WikiFandom + Images)")
+    gr.Markdown("Chat with TinyLlama. Build a knowledge base from text, WikiFandom URLs, or Images.")
 
     with gr.Row():
         # Left Column: Chat
         with gr.Column(scale=2):
-            gr.Markdown("### 💬 Chat Interface")
             chat_interface = gr.ChatInterface(
                 predict,
-                examples=['How to cook a fish?', 'Who is the president of US now?', 'Explain quantum computing simply'],
-                cache_examples=False
+                examples=['Who is in the image?', 'Tell me about the wiki page'],
             )
 
-        # Right Column: Settings & RAG
+        # Right Column: Tools
         with gr.Column(scale=1):
 
-            # RAG Section
-            with gr.Accordion("📚 RAG / Knowledge Base", open=True):
-                gr.Markdown("Paste text below to give the AI specific knowledge.")
+            # --- RAG INPUTS ---
+            with gr.Accordion("📚 Knowledge Sources", open=True):
+
+                # Main Text Area (Where all data ends up)
                 kb_input = gr.Textbox(
-                    label="Reference Text",
-                    lines=8,
-                    placeholder="Paste an article, email, or documentation here...",
-                    info="The AI will search this text to answer your questions."
+                    label="Compiled Knowledge Base",
+                    lines=6,
+                    placeholder="Data from Wiki or Images will appear here...",
+                    interactive=True
                 )
+
+                with gr.Tab("🔗 WikiFandom"):
+                    url_input = gr.Textbox(label="Fandom URL", placeholder="https://starwars.fandom.com/wiki/Luke_Skywalker")
+                    scrape_btn = gr.Button("Scrape & Add Text")
+
+                with gr.Tab("🖼️ Image Support"):
+                    img_input = gr.Image(type="pil", label="Upload Image")
+                    img_btn = gr.Button("Analyze & Add Description")
+
+                # Build Button
                 with gr.Row():
                     process_btn = gr.Button("Build Knowledge Base", variant="primary")
                     rag_status = gr.Checkbox(label="RAG Active", interactive=False, value=False)
-                kb_output = gr.Textbox(label="Status", interactive=False)
-
-            # System Prompt Section
-            with gr.Accordion("⚙️ System Personality", open=False):
-                system_content_input = gr.Textbox(
-                    label="System Content",
-                    value=SYSTEM_CONTENT,
-                    lines=4
-                )
-                with gr.Row():
-                    update_btn = gr.Button("Update System")
-                    reset_btn = gr.Button("Reset")
-                system_status = gr.Textbox(label="Status", interactive=False)
-
-            gr.Markdown("### ℹ️ About")
-            gr.Markdown("""
-            **Model:** TinyLlama-1.1B
-            **RAG:** sentence-transformers (all-MiniLM-L6-v2)
-
-            **How to use RAG:**
-            1. Paste text into 'Reference Text'.
-            2. Click 'Build Knowledge Base'.
-            3. Ask questions about that text.
-            """)
+
+                status_output = gr.Textbox(label="Status", interactive=False)
 
-    # Event Handlers
-    process_btn.click(
-        process_knowledge_base,
-        inputs=[kb_input],
-        outputs=[kb_output, rag_status]
+            # System Prompt
+            with gr.Accordion("⚙️ System Settings", open=False):
+                system_content_input = gr.Textbox(value=SYSTEM_CONTENT, lines=2, label="System Prompt")
+
+    # --- EVENT HANDLERS ---
+
+    # 1. Scrape Fandom -> Append to Textbox
+    scrape_btn.click(
+        add_fandom_content,
+        inputs=[url_input, kb_input],
+        outputs=[kb_input, status_output]
     )
 
-    update_btn.click(
-        update_system_content,
-        inputs=[system_content_input],
-        outputs=[system_status]
+    # 2. Analyze Image -> Append to Textbox
+    img_btn.click(
+        add_image_content,
+        inputs=[img_input, kb_input],
+        outputs=[kb_input, status_output]
     )
-
-    reset_btn.click(
-        reset_system_content,
-        outputs=[system_content_input, system_status]
+
+    # 3. Build RAG Index
+    process_btn.click(
+        process_knowledge_base,
+        inputs=[kb_input],
+        outputs=[status_output, rag_status]
     )
 
 if __name__ == "__main__":
-    demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
+    demo.launch()
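For reviewers who want to poke at the threshold change in retrieve_context (0.3 -> 0.25): the retrieval step reduces to a single cosine-similarity pass over the chunk embeddings. A minimal standalone sketch using the same embedder and threshold as the diff; the two chunks are hypothetical toy data:

import torch
from sentence_transformers import SentenceTransformer, util

embedder = SentenceTransformer('all-MiniLM-L6-v2')

# Hypothetical toy knowledge base, already chunked
chunks = ["Luke Skywalker was trained by Obi-Wan Kenobi and Yoda.",
          "Tatooine is a desert planet with twin suns."]
chunk_embeddings = embedder.encode(chunks, convert_to_tensor=True)

# Embed the query and score it against every chunk
query_embedding = embedder.encode("Who trained Luke?", convert_to_tensor=True)
cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]

# Keep the top-k chunks that clear the 0.25 relevance threshold
top = torch.topk(cos_scores, k=min(3, len(chunks)))
print([chunks[i] for s, i in zip(top.values, top.indices) if s > 0.25])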
 
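The BLIP captioning behind process_image_to_text can likewise be tried in isolation. A sketch with the same checkpoint the commit loads; "photo.jpg" is a placeholder path:

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("photo.jpg").convert("RGB")  # placeholder path
inputs = processor(image, return_tensors="pt")

# Caption generation, capped at 50 new tokens as in the commit
out = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(out[0], skip_special_tokens=True))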
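Finally, a hypothetical smoke test for the full new path (scrape, index, retrieve), assuming app.py is importable from the working directory. Note that the import itself loads all three models, and the URL is the placeholder shown in the UI:

# Uses only functions defined in app.py after this commit
from app import scrape_wikifandom, process_knowledge_base, retrieve_context

text = scrape_wikifandom("https://starwars.fandom.com/wiki/Luke_Skywalker")
if text.startswith("Error"):
    print(text)
else:
    status, ready = process_knowledge_base(text)  # returns (message, rag_enabled)
    print(status)
    if ready:
        print(retrieve_context("Who trained Luke Skywalker?"))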