Update app.py
app.py (CHANGED)
@@ -120,7 +120,7 @@ class DOLPHIN:
             do_sample=False,
             num_beams=1,
             repetition_penalty=1.1,
-            temperature=0
+            temperature=1.0
         )
 
         sequences = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)
@@ -185,15 +185,13 @@ Provide a descriptive alt text in 1-2 sentences that is informative but not over
         )
         input_len = input_ids["input_ids"].shape[-1]
 
-        input_ids = input_ids.to(self.model.device)
+        input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
         outputs = self.model.generate(
             **input_ids,
             max_new_tokens=256,
             disable_compile=True,
             do_sample=False,
-            temperature=0.,
-            pad_token_id=self.processor.tokenizer.pad_token_id,
-            eos_token_id=self.processor.tokenizer.eos_token_id
+            temperature=0.1
         )
 
         text = self.processor.batch_decode(
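A note on the input_ids.to(self.model.device, dtype=self.model.dtype) line added above: assuming the processor returns a transformers BatchFeature (as multimodal processors typically do), .to(device, dtype=...) casts only floating-point tensors such as pixel_values; integer tensors like input_ids are just moved to the device, never cast. A minimal illustrative sketch with made-up shapes:

import torch
from transformers import BatchFeature

batch = BatchFeature({
    "input_ids": torch.tensor([[1, 2, 3]]),       # integer ids: moved, never cast
    "pixel_values": torch.randn(1, 3, 224, 224),  # float tensor: cast to target dtype
})
batch = batch.to("cpu", dtype=torch.bfloat16)
print(batch["input_ids"].dtype)     # torch.int64
print(batch["pixel_values"].dtype)  # torch.bfloat16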
@@ -246,15 +244,13 @@ Provide a descriptive alt text in 1-2 sentences that is informative but not over
         )
         input_len = input_ids["input_ids"].shape[-1]
 
-        input_ids = input_ids.to(self.model.device)
+        input_ids = input_ids.to(self.model.device, dtype=self.model.dtype)
         outputs = self.model.generate(
             **input_ids,
             max_new_tokens=1024,
             disable_compile=True,
-            do_sample=False,
-            temperature=0.,
-            pad_token_id=self.processor.tokenizer.pad_token_id,
-            eos_token_id=self.processor.tokenizer.eos_token_id
+            do_sample=True,
+            temperature=0.7
         )
 
         text = self.processor.batch_decode(
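For context on the decoding changes in the hunks above: in transformers' generate, temperature only has an effect when do_sample=True; under greedy decoding (do_sample=False) a non-default temperature is ignored, and recent versions emit a warning about it. A minimal stand-alone sketch of the two modes (gpt2 is purely a stand-in model, not the one this Space uses):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("The quick brown", return_tensors="pt")

# Greedy decoding: deterministic; a temperature setting would be ignored here.
greedy = model.generate(**inputs, do_sample=False, max_new_tokens=10)

# Sampling: temperature < 1.0 sharpens the distribution, > 1.0 flattens it.
sampled = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=10)

print(tok.decode(greedy[0], skip_special_tokens=True))
print(tok.decode(sampled[0], skip_special_tokens=True))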
@@ -690,7 +686,7 @@ def create_embeddings(chunks):
 def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
     """Retrieve most relevant chunks for a question"""
     if embedding_model is None or embeddings is None:
-        return chunks[:3]
+        return chunks[:3]  # Fallback to first 3 chunks
 
     try:
         question_embedding = embedding_model.encode([question], show_progress_bar=False)
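The body of retrieve_relevant_chunks after the early return is not shown in this hunk; a typical implementation ranks chunks by cosine similarity between the question embedding and the precomputed chunk embeddings. A hedged sketch under that assumption (sentence-transformers API; all-MiniLM-L6-v2 is an illustrative model choice, not necessarily what this Space uses):

import numpy as np
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

def retrieve_relevant_chunks_sketch(question, chunks, embeddings, top_k=3):
    """Rank chunks by cosine similarity to the question (illustrative only)."""
    q = embedding_model.encode([question], show_progress_bar=False)[0]
    sims = embeddings @ q / (np.linalg.norm(embeddings, axis=1) * np.linalg.norm(q) + 1e-8)
    top = np.argsort(-sims)[:top_k]
    return [chunks[i] for i in top]

chunks = ["Dolphins are mammals.", "The report covers Q3.", "Revenue grew 12%."]
embeddings = embedding_model.encode(chunks, show_progress_bar=False)
print(retrieve_relevant_chunks_sketch("How much did revenue grow?", chunks, embeddings))

The small epsilon guards against zero-norm vectors; with embeddings normalized at encode time, the cosine reduces to a plain dot product.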
@@ -982,49 +978,31 @@ with gr.Blocks(
             return history + [{"role": "user", "content": message}, {"role": "assistant", "content": "❌ Please process a PDF document first before asking questions."}]
 
         try:
-            #
-
-
-
-
-
-
-
-
-                    response_text = "You're welcome! Feel free to ask me anything about the document."
-                else:
-                    response_text = "Hello! How can I help you understand the document better?"
+            # Use RAG to get relevant chunks from markdown
+            if document_chunks and len(document_chunks) > 0:
+                relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings, top_k=3)
+                context = "\n\n".join(relevant_chunks)
+                # Smart truncation: aim for ~6000 chars for local model
+                if len(context) > 6000:
+                    # Try to cut at sentence boundaries
+                    sentences = context[:6000].split('.')
+                    context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:6000] + '...'
             else:
-                #
-
-
-
-
-                if len(context) > 6000:
-                    # Try to cut at sentence boundaries
-                    sentences = context[:6000].split('.')
-                    context = '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:6000] + '...'
-                else:
-                    # Fallback to truncated document if RAG fails
-                    context = processed_markdown[:6000] + "..." if len(processed_markdown) > 6000 else processed_markdown
-
-                # Create prompt for Gemma 3n
-                prompt = f"""You are a helpful assistant that answers questions about documents. Answer concisely and directly based on the provided context. If the context doesn't contain relevant information, say so briefly and offer to help with other questions about the document.
+                # Fallback to truncated document if RAG fails
+                context = processed_markdown[:6000] + "..." if len(processed_markdown) > 6000 else processed_markdown
+
+            # Create prompt for Gemma 3n
+            prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
 
 Context from the document:
 {context}
 
 Question: {message}
 
-
-
-
-
-
-            # Clean up repetitive text and Korean characters
-            response_text = response_text.split('답변:')[0].strip()  # Remove Korean repetitions
-            response_text = response_text.split('Answer:')[-1].strip()  # Clean prompt artifacts
-
+Please provide a clear and helpful answer based on the context provided."""
+
+            # Generate response using local Gemma 3n
+            response_text = gemma_model.chat(prompt)
             return history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]
 
         except Exception as e:
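The smart-truncation one-liner added in the last hunk is easy to misread; here is its behavior checked standalone (smart_truncate is a hypothetical helper name, same logic as the diff):

def smart_truncate(context: str, limit: int = 6000) -> str:
    """Cut at the last full stop within the first `limit` chars, else hard cut."""
    if len(context) <= limit:
        return context
    sentences = context[:limit].split('.')
    return '.'.join(sentences[:-1]) + '...' if len(sentences) > 1 else context[:limit] + '...'

print(smart_truncate("A short context."))            # unchanged
print(smart_truncate("x" * 20, limit=10))            # no '.' found -> hard cut + '...'
print(smart_truncate("One. Two. Three.", limit=12))  # -> 'One. Two...'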