Commit 3f965c2 (verified) by raksama19 · Parent(s): 31f1c8b

Update app.py

Files changed (1):
  1. app.py +25 -47

app.py CHANGED
@@ -120,7 +120,7 @@ class DOLPHIN:
             do_sample=False,
             num_beams=1,
             repetition_penalty=1.1,
-            temperature=0.2
+            temperature=1.0
         )
 
         sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
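Worth noting about this hunk: with `do_sample=False` and `num_beams=1`, `generate()` decodes greedily and never applies `temperature`, but recent transformers releases warn when a non-default temperature is passed alongside greedy decoding, so moving it to the 1.0 default silences that warning without changing output. A minimal sketch of the distinction, using a placeholder `gpt2` checkpoint rather than the app's model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder, not the app's checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Parse this document:", return_tensors="pt")

# Greedy decoding: temperature is never applied, so leave it at its 1.0 default
# (recent transformers versions warn if a non-default value is passed here).
greedy = model.generate(**inputs, do_sample=False, num_beams=1,
                        repetition_penalty=1.1, max_new_tokens=32)

# Only with sampling enabled does temperature reshape the output distribution.
sampled = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=32)
print(tokenizer.decode(greedy[0], skip_special_tokens=True))
```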
@@ -185,15 +185,13 @@ Provide a descriptive alt text in 1-2 sentences that is informative but not over
         )
         input_len = input_ids["input_ids"].shape[-1]
 
-        input_ids = input_ids.to(self.model.device)
+        input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
         outputs = self.model.generate(
             **input_ids,
             max_new_tokens=256,
             disable_compile=True,
             do_sample=False,
-            temperature=0.2,
-            pad_token_id=self.processor.tokenizer.pad_token_id,
-            eos_token_id=self.processor.tokenizer.eos_token_id
+            temperature=0.1
         )
 
         text = self.processor.batch_decode(
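A note on the `.to(self.model.device, dtype=self.model.dtype)` change: assuming the processor returns a transformers `BatchFeature` (multimodal processors do), its `.to()` casts only floating-point tensors to the requested dtype and merely moves integer tensors, so `input_ids` stay integer while `pixel_values` end up in the model's dtype. A small sketch of that behavior:

```python
import torch
from transformers import BatchFeature

batch = BatchFeature({
    "input_ids": torch.tensor([[1, 2, 3]]),       # integer token ids
    "pixel_values": torch.randn(1, 3, 224, 224),  # floating-point image tensor
})

moved = batch.to("cpu", dtype=torch.bfloat16)
print(moved["input_ids"].dtype)     # torch.int64: ids are moved, not cast
print(moved["pixel_values"].dtype)  # torch.bfloat16: floats follow the model dtype
```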
@@ -246,15 +244,13 @@ Provide a descriptive alt text in 1-2 sentences that is informative but not over
         )
         input_len = input_ids["input_ids"].shape[-1]
 
-        input_ids = input_ids.to(self.model.device)
+        input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
         outputs = self.model.generate(
             **input_ids,
             max_new_tokens=1024,
             disable_compile=True,
-            do_sample=False,
-            temperature=0.2,
-            pad_token_id=self.processor.tokenizer.pad_token_id,
-            eos_token_id=self.processor.tokenizer.eos_token_id
+            do_sample=True,
+            temperature=0.7
        )
 
         text = self.processor.batch_decode(
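Since this hunk turns sampling on (`do_sample=True, temperature=0.7`), the generated text will differ from run to run. If repeatable output matters (for example when regression-testing descriptions), transformers' `set_seed` pins the Python, NumPy and torch RNGs; a sketch with a placeholder model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder, not the app's model
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Describe the table:", return_tensors="pt")

set_seed(42)  # seeds python, numpy and torch in one call
out_a = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=16)
set_seed(42)
out_b = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=16)
assert torch.equal(out_a, out_b)  # identical seeds give identical samples
```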
@@ -690,7 +686,7 @@ def create_embeddings(chunks):
 def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
     """Retrieve most relevant chunks for a question"""
     if embedding_model is None or embeddings is None:
-        return chunks[:3]  # Fallback to first 3 chunks
+        return chunks[:3]  # Fallback to first 3 chunks
 
     try:
         question_embedding = embedding_model.encode([question], show_progress_bar=False)
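The body of the `try` block sits outside this diff, so as a reading aid, here is a hedged sketch of what a retrieval step with this signature typically does. The cosine-similarity math and the `all-MiniLM-L6-v2` model name are assumptions for illustration, not code from the app:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed embedding model

def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
    """Sketch: rank chunks by cosine similarity to the question embedding."""
    if embedding_model is None or embeddings is None:
        return chunks[:3]  # fallback mirrors the diff
    q = embedding_model.encode([question], show_progress_bar=False)[0]
    # cosine similarity of the question against each precomputed chunk embedding
    sims = embeddings @ q / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(q) + 1e-10)
    top = np.argsort(sims)[::-1][:top_k]
    return [chunks[i] for i in top]
```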
@@ -982,49 +978,31 @@ with gr.Blocks(
             return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Please process a PDF document first before asking questions."}]
 
         try:
-            # Check if it's a simple greeting or conversational message
-            greeting_words = ['hi', 'hello', 'hey', 'good morning', 'good afternoon', 'good evening', 'thanks', 'thank you']
-            is_greeting = any(greeting.lower() in message.lower() for greeting in greeting_words)
-
-            if is_greeting and len(message.split()) <= 3:
-                # Handle simple greetings without RAG
-                if 'hi' in message.lower() or 'hello' in message.lower() or 'hey' in message.lower():
-                    response_text = "Hello! I'm here to help you with questions about your processed document. What would you like to know?"
-                elif 'thank' in message.lower():
-                    response_text = "You're welcome! Feel free to ask me anything about the document."
-                else:
-                    response_text = "Hello! How can I help you understand the document better?"
+            # Use RAG to get relevant chunks from markdown
+            if document_chunks and len(document_chunks) > 0:
+                relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
+                context = "\n\n".join(relevant_chunks)
+                # Smart truncation: aim for ~6000 chars for local model
+                if len(context) > 6000:
+                    # Try to cut at sentence boundaries
+                    sentences = context[:6000].split('.')
+                    context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:6000] + '...'
             else:
-                # Use RAG for document-related questions
-                if document_chunks and len(document_chunks) > 0:
-                    relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
-                    context = "\n\n".join(relevant_chunks)
-                    # Smart truncation: aim for ~6000 chars for local model
-                    if len(context) > 6000:
-                        # Try to cut at sentence boundaries
-                        sentences = context[:6000].split('.')
-                        context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:6000] + '...'
-                else:
-                    # Fallback to truncated document if RAG fails
-                    context = processed_markdown[:6000] + "..." if len(processed_markdown) > 6000 else processed_markdown
-
-                # Create prompt for Gemma 3n
-                prompt = f"""You are a helpful assistant that answers questions about documents. Answer concisely and directly based on the provided context. If the context doesn't contain relevant information, say so briefly and offer to help with other questions about the document.
+                # Fallback to truncated document if RAG fails
+                context = processed_markdown[:6000] + "..." if len(processed_markdown) > 6000 else processed_markdown
+
+            # Create prompt for Gemma 3n
+            prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
 
 Context from the document:
 {context}
 
 Question: {message}
 
-Answer:"""
-
-                # Generate response using local Gemma 3n
-                response_text = gemma_model.chat(prompt)
-
-                # Clean up repetitive text and Korean characters
-                response_text = response_text.split('답변:')[0].strip()  # Remove Korean repetitions
-                response_text = response_text.split('Answer:')[-1].strip()  # Clean prompt artifacts
-
+Please provide a clear and helpful answer based on the context provided."""
+
+            # Generate response using local Gemma 3n
+            response_text = gemma_model.chat(prompt)
+
             return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]
 
         except Exception as e:
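The "smart truncation" one-liner kept in this hunk leans on Python's conditional-expression precedence: `a + b if cond else c` parses as `(a + b) if cond else c`, so it is correct as written. A stand-alone equivalent (the helper name is hypothetical, not from the app) that makes the intent explicit:

```python
def truncate_at_sentence(text, limit=6000):
    """Equivalent of the hunk's one-liner: cut at the last complete sentence."""
    if len(text) <= limit:
        return text
    sentences = text[:limit].split('.')
    if len(sentences) > 1:
        return '.'.join(sentences[:-1]) + '...'  # drop the sentence the cut landed in
    return text[:limit] + '...'                  # no period found; hard cut

print(truncate_at_sentence("one. two. three", limit=9))  # -> "one. two..."
```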
 