# Provenance: Hugging Face Space by rbbist, commit 701e4ef ("Reversing last edit")
# app.py
import gradio as gr
from chromadb_semantic_search_for_dataset import semantic_search, build_compact_context
from transformers import pipeline, AutoTokenizer, MT5ForConditionalGeneration
import time
import torch
# Candidate checkpoints, in order of preference. MT5 proved unstable for this
# task, so the instruction-tuned flan-t5-base is tried first.
MODELS_TO_TRY = [
    "google/flan-t5-base",  # Better for instruction following
    "google/mt5-base",      # Smaller, more stable than mt5-large
    # "google/mt5-large"    # Original choice - may have issues
]
SUMMARY_MODEL = MODELS_TO_TRY[0]  # Start with flan-t5-base
ANSWER_MODEL = MODELS_TO_TRY[0]   # Use same model for consistency
print(f"Loading models: {SUMMARY_MODEL}")
# Build both pipelines at import time on CPU. float32 + low_cpu_mem_usage keeps
# the load stable on memory-constrained Spaces hardware.
try:
    summarizer = pipeline(
        "text2text-generation",
        model=SUMMARY_MODEL,
        device=-1,  # CPU
        model_kwargs={
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True
        }
    )
    answerer = pipeline(
        "text2text-generation",
        model=ANSWER_MODEL,
        device=-1,  # CPU
        model_kwargs={
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True
        }
    )
    print("Models loaded successfully!")
except Exception as e:
    # A model that fails to load makes the whole app unusable - surface it.
    print(f"Error loading models: {e}")
    raise
# Keep the last search context in memory so the RAG tab can reuse the previous
# search when the user does not provide a new query.
_last_combined_context: str = ""
_last_search_query: str = ""
def create_template_answer(context: str, question: str) -> str:
    """Build a deterministic Nepali answer directly from the search context.

    Fallback used when the generation model fails or produces an empty or
    meaningless output, so the user always gets an answer grounded in the
    retrieved cases.

    Args:
        context: Compact combined context whose cases are delimited by
            "[Case " markers (as produced by semantic_search_ui).
        question: The user's question; Nepali keywords in it select the
            opening sentence of the answer.

    Returns:
        A Nepali answer listing up to three cases plus a closing note.
    """
    print("DEBUG: Creating template-based answer")
    # Split back into per-case chunks; element 0 is the text before the first
    # "[Case " marker, so it is skipped below.
    cases = context.split("[Case ")
    case_info = []
    for i, case in enumerate(cases[1:], 1):
        try:
            # The summaries use a "key: value | key: value" layout; pull out
            # the case type and subject, defaulting to "N/A" when absent.
            case_type = case.split("मुद्दाको किसिम: ")[1].split(" |")[0] if "मुद्दाको किसिम: " in case else "N/A"
            subject = case.split("विषय: ")[1].split(" |")[0] if "विषय: " in case else "N/A"
            # NOTE: an unused 'snippet' field was dropped from this record.
            case_info.append({
                'number': i,
                'type': case_type,
                'subject': subject,
            })
        except IndexError:
            # Narrowed from a bare `except:` - IndexError is the only error the
            # split/index chain can raise here. Malformed chunks are skipped
            # instead of aborting the whole answer.
            continue
    # Choose the opening sentence from keywords found in the question.
    question_lower = question.lower()
    if any(word in question_lower for word in ['हक', 'अधिकार', 'कायम']):
        answer = f"तपाईंको प्रश्न 'हक कायम' संबंधी छ। उपलब्ध {len(case_info)} केसहरूमा:\n\n"
    elif any(word in question_lower for word in ['फैसला', 'बदर', 'निर्णय']):
        answer = f"तपाईंको प्रश्न 'फैसला बदर' संबंधी छ। उपलब्ध {len(case_info)} केसहरूमा:\n\n"
    elif any(word in question_lower for word in ['अंश', 'दर्ता', 'बाँडफाँड']):
        answer = f"तपाईंको प्रश्न 'अंश दर्ता' संबंधी छ। उपलब्ध {len(case_info)} केसहरूमा:\n\n"
    else:
        answer = f"तपाईंको प्रश्नको सम्बन्धमा {len(case_info)} केसहरू भेटिएका छन्:\n\n"
    # List at most the first three cases.
    for case in case_info[:3]:
        answer += f"केस {case['number']}: {case['type']} - {case['subject']}\n"
        answer += f"मुख्य विषय: {case['subject']}\n\n"
    answer += "विस्तृत जानकारीका लागि माथिका केसहरूको लिङ्कहरू हेर्नुहोस्।"
    return answer
def semantic_search_ui(search_text: str):
    """Run semantic search for *search_text* and return the formatted hits.

    Also caches the summarized compact context in module globals so the RAG
    tab can answer questions without re-running the search.

    Returns:
        (formatted results string, compact combined context string), or an
        error message and "" when the search fails.
    """
    global _last_combined_context, _last_search_query
    print(f"DEBUG: Starting semantic search for: {search_text}")
    try:
        formatted, top_docs, combined_context = semantic_search(search_text, n_results=3)
        print(f"DEBUG: Retrieved {len(top_docs)} documents")
        # Summarize each hit manually from its metadata and text; this is more
        # reliable than model-based summarization.
        legal_terms = ['फैसला', 'ठहर', 'अदालत', 'मुद्दा', 'कानुन']
        summaries = []
        for rank, hit in enumerate(top_docs, start=1):
            raw_text = hit["document"]
            meta = hit["metadata"]
            print(f"DEBUG: Processing document {rank}, length: {len(raw_text)}")
            pieces = []
            # Key metadata fields first (party names truncated to 100 chars).
            if meta.get('mudda_type'):
                pieces.append(f"मुद्दाको किसिम: {meta['mudda_type']}")
            if meta.get('subject'):
                pieces.append(f"विषय: {meta['subject']}")
            if meta.get('nibedak'):
                pieces.append(f"निवेदक: {meta['nibedak'][:100]}...")
            if meta.get('vipakshi'):
                pieces.append(f"विपक्षी: {meta['vipakshi'][:100]}...")
            # Strip list/JSON artifacts and escaped newlines from the raw text.
            cleaned = raw_text.replace('["', '').replace('"]', '').replace('\\n', ' ')
            # From the first five Nepali sentences ('।'-delimited), keep up to
            # 200 chars of each one that is long enough and mentions a key
            # legal term.
            leading = [s.strip() for s in cleaned.split('।')[:5]]
            highlights = [
                s[:200] for s in leading
                if len(s) > 20 and any(term in s.lower() for term in legal_terms)
            ]
            if highlights:
                pieces.append("मुख्य बुँदाहरू: " + "। ".join(highlights[:2]) + "।")
            else:
                # No qualifying sentences: fall back to the document opening.
                opening = cleaned[:300].strip()
                if opening:
                    pieces.append(f"विवरण: {opening}...")
            manual_summary = " | ".join(pieces)
            summaries.append(manual_summary)
            print(f"DEBUG: Created manual summary {rank}: {manual_summary[:100]}...")
        # Length-limited combined context for the answerer.
        compact_context = build_compact_context(summaries)
        print(f"DEBUG: Built compact context, length: {len(compact_context)}")
        print(f"DEBUG: Context preview: {compact_context[:200]}...")
        # Remember this search so the Ask flow can reuse it.
        _last_combined_context = compact_context
        _last_search_query = search_text
        return formatted, compact_context
    except Exception as e:
        error_msg = f"Error in semantic search: {e}"
        print(f"DEBUG: {error_msg}")
        return error_msg, ""
def rag_answer(question: str, search_text_for_context: str = "") -> str:
    """
    Answer the user's question using RAG.

    - If search_text_for_context is provided, run a fresh semantic search and
      use its summaries as context.
    - Otherwise, reuse the last search context stored in memory
      (_last_combined_context).

    Falls back to a template-based answer whenever model generation fails or
    yields an empty/meaningless output, so a string is always returned.
    """
    global _last_combined_context, _last_search_query
    print(f"DEBUG: RAG answer called with question: {question[:50]}...")
    start_time = time.time()
    # If the user provided a search string in the RAG tab, refresh the context.
    if search_text_for_context and search_text_for_context.strip():
        print("DEBUG: Refreshing context with new search")
        _, compact_context = semantic_search_ui(search_text_for_context)
        context = compact_context
    else:
        context = _last_combined_context
        print(f"DEBUG: Using cached context, length: {len(context)}")
    # Bail out early when there is effectively no context to ground the answer.
    if not context or len(context.strip()) < 50:
        return "No sufficient context available. Please run a semantic search first or provide a search query."
    print(f"DEBUG: Using context: {context[:300]}...")
    # flan-t5 follows English instructions best; other models get a fully
    # Nepali prompt. Context is capped at 1500 chars to fit the model input.
    if "flan-t5" in ANSWER_MODEL.lower():
        prompt = f"Based on these Nepali legal case summaries, answer the question in Nepali:\n\nContext: {context[:1500]}\n\nQuestion: {question}\n\nProvide a detailed answer in Nepali:"
    else:
        prompt = (
            "तलका नेपाली अदालती मुद्दाका विवरणहरू प्रयोग गरेर प्रश्नको जवाफ नेपालीमा दिनुहोस्:\n\n"
            f"मुद्दाहरूको विवरण:\n{context[:1500]}\n\n"
            f"प्रश्न: {question}\n\n"
            "विस्तृत जवाफ:"
        )
    print(f"DEBUG: Generated prompt length: {len(prompt)}")
    print(f"DEBUG: Prompt preview: {prompt[:200]}...")
    try:
        print(f"DEBUG: Sending prompt to model (length: {len(prompt)})")
        print(f"DEBUG: First 300 chars of prompt: {prompt[:300]}")
        # Try progressively simpler generation strategies before giving up.
        result = None
        # Strategy 1: greedy decoding with an explicit pad token (some
        # tokenizers lack pad_token_id, hence the hasattr fallback to EOS).
        try:
            result = answerer(
                prompt,
                max_length=300,
                min_length=10,
                do_sample=False,
                num_beams=1,
                pad_token_id=answerer.tokenizer.pad_token_id if hasattr(answerer.tokenizer, 'pad_token_id') else answerer.tokenizer.eos_token_id
            )
            print(f"DEBUG: Strategy 1 successful")
        except Exception as e1:
            print(f"DEBUG: Strategy 1 failed: {e1}")
            # Strategy 2: minimal-argument generation.
            try:
                result = answerer(prompt, max_length=200, do_sample=False)
                print(f"DEBUG: Strategy 2 successful")
            except Exception as e2:
                print(f"DEBUG: Strategy 2 failed: {e2}")
                # Strategy 3: template-based fallback (no model involved).
                print(f"DEBUG: Using template-based fallback")
                template_answer = create_template_answer(context, question)
                return template_answer + f"\n\n---\n(Generated using template fallback in {time.time() - start_time:.2f}s)"
        if result:
            out = result[0]["generated_text"].strip()
            print(f"DEBUG: Raw model output: '{out}'")
            print(f"DEBUG: Output length: {len(out)}")
            # Some seq2seq models echo the prompt; strip it from the output.
            if prompt in out:
                out = out.replace(prompt, "").strip()
                print(f"DEBUG: Cleaned output: '{out[:100]}...'")
            # Replace empty or meaningless generations with the template answer.
            if not out or len(out) < 5 or out.lower() in ['none', 'n/a', '']:
                print(f"DEBUG: Output too short or meaningless, using template fallback")
                out = create_template_answer(context, question)
        else:
            print(f"DEBUG: No result from model, using template fallback")
            out = create_template_answer(context, question)
    except Exception as e:
        # Last-resort safety net around the whole generation path.
        print(f"DEBUG: All strategies failed: {e}")
        out = create_template_answer(context, question)
    elapsed = time.time() - start_time
    footer = f"\n\n---\n(Generated in {elapsed:.2f}s using summaries of top-3 cases.)"
    return out + footer
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Semantic Search + RAG (auto-summarize top-3) — Nepali cases")
    gr.Markdown("**Debug Info**: Using models: " + SUMMARY_MODEL)
    # Tab 1: plain semantic search; also refreshes the cached RAG context.
    with gr.Tab("🔍 Semantic Search"):
        search_input = gr.Textbox(
            label="Search for a case (use Nepali preferred)",
            placeholder="मुद्दाको संक्षेप वा कीवर्ड टाइप गर्नुहोस्..."
        )
        search_button = gr.Button("Search")
        search_results = gr.Markdown(label="Top 3 Similar Cases (formatted)")
        context_preview = gr.Textbox(
            label="Combined Summarized Context (for RAG)",
            interactive=False,
            max_lines=10
        )
        search_button.click(
            fn=semantic_search_ui,
            inputs=search_input,
            outputs=[search_results, context_preview]
        )
    # Tab 2: RAG question answering over the (cached or refreshed) context.
    with gr.Tab("🤖 Ask a Question (RAG)"):
        question_input = gr.Textbox(
            label="Your question (Nepali)",
            placeholder="यहाँ प्रश्न लेख्नुहोस्..."
        )
        optional_search_input = gr.Textbox(
            label="Optional: Search query to refresh context",
            placeholder="(Optional) provide a search query to refresh top-3 context"
        )
        ask_button = gr.Button("Get Answer")
        rag_output = gr.Markdown(label="LLM Answer (based on summarized top-3)")
        ask_button.click(
            fn=rag_answer,
            inputs=[question_input, optional_search_input],
            outputs=rag_output
        )
    # Tab 3: quick diagnostic to check the answerer pipeline is responsive.
    with gr.Tab("🐛 Test Model"):
        test_input = gr.Textbox(label="Test input", placeholder="Enter test text...")
        test_button = gr.Button("Test Model")
        test_output = gr.Textbox(label="Model output")
        def test_model(text):
            """Smoke-test the answerer with one English and one Nepali prompt."""
            if not text.strip():
                return "Please enter some text to test"
            try:
                # Test 1: very simple English prompt.
                simple_result = answerer(f"Translate to English: {text}", max_length=50, do_sample=False)
                result1 = simple_result[0]["generated_text"]
                # Test 2: Nepali prompt.
                nepali_result = answerer(f"यसलाई नेपालीमा भन्नुहोस्: {text}", max_length=50, do_sample=False)
                result2 = nepali_result[0]["generated_text"]
                return f"English test: {result1}\n\nNepali test: {result2}\n\nModel is working!"
            except Exception as e:
                return f"Model test failed: {e}\n\nModel details:\n- Name: {ANSWER_MODEL}\n- Type: {type(answerer)}"
        test_button.click(fn=test_model, inputs=test_input, outputs=test_output)
    gr.Markdown("""
**Notes**:
- The system summarizes the top-3 semantic results and uses those summaries as context for the LLM
- If you experience issues, try the Test Model tab first
- Check the console logs for debugging information
""")
if __name__ == "__main__":
    demo.launch(debug=True)