|
|
|
|
|
import gradio as gr |
|
|
from chromadb_semantic_search_for_dataset import semantic_search, build_compact_context |
|
|
from transformers import pipeline, AutoTokenizer, MT5ForConditionalGeneration |
|
|
import time |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# Candidate text2text-generation checkpoints, in order of preference.
# Only index 0 is actually loaded below; mt5-base is kept as a manual
# fallback to swap in if flan-t5-base misbehaves.
MODELS_TO_TRY = [

    "google/flan-t5-base",

    "google/mt5-base",

]

# Both roles (summarization and answering) currently use the same
# checkpoint; kept as two names so they can diverge later.
SUMMARY_MODEL = MODELS_TO_TRY[0]

ANSWER_MODEL = MODELS_TO_TRY[0]

print(f"Loading models: {SUMMARY_MODEL}")
|
|
|
|
|
|
|
|
def _load_text2text(model_name):
    """Build a CPU-only text2text-generation pipeline for *model_name*."""
    return pipeline(
        "text2text-generation",
        model=model_name,
        device=-1,  # -1 pins the pipeline to CPU
        model_kwargs={
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True
        }
    )


# Load both pipelines eagerly at import time; if either checkpoint fails,
# log the error and re-raise so the app never starts half-initialized.
try:
    summarizer = _load_text2text(SUMMARY_MODEL)
    answerer = _load_text2text(ANSWER_MODEL)
    print("Models loaded successfully!")
except Exception as e:
    print(f"Error loading models: {e}")
    raise


# Module-level cache of the most recent search: the combined summarized
# context fed to RAG, and the query that produced it.
_last_combined_context = ""
_last_search_query = ""
|
|
|
|
|
def create_template_answer(context: str, question: str) -> str:
    """Create a template-based Nepali answer when the model fails.

    Parses the ``[Case N]``-delimited summaries in *context*, extracts the
    case-type and subject fields from their ``label: value |`` pairs, and
    assembles a deterministic answer whose opening line depends on topic
    keywords found in *question*.

    Args:
        context: Combined summarized context (``[Case N] ...`` sections).
        question: The user's question (Nepali expected; Devanagari has no
            case, so ``lower()`` only affects any Latin text).

    Returns:
        A Nepali answer string listing at most the first three parsed cases.
    """
    print("DEBUG: Creating template-based answer")

    # Each retrieved case is introduced by a "[Case N]" marker; the text
    # before the first marker (cases[0]) is ignored.
    cases = context.split("[Case ")
    case_info = []

    for i, case in enumerate(cases[1:], 1):
        try:
            # Field values are terminated by " |"; fall back to "N/A"
            # when a label is absent from this case's summary.
            case_type_match = case.split("मुद्दाको किसिम: ")[1].split(" |")[0] if "मुद्दाको किसिम: " in case else "N/A"
            subject_match = case.split("विषय: ")[1].split(" |")[0] if "विषय: " in case else "N/A"
            case_info.append({
                'number': i,
                'type': case_type_match,
                'subject': subject_match,
            })
        # Was a bare "except:"; never swallow SystemExit/KeyboardInterrupt.
        # A malformed case is simply skipped.
        except Exception:
            continue

    question_lower = question.lower()

    # Choose an opening line based on which topic keywords the question hits.
    if any(word in question_lower for word in ['हक', 'अधिकार', 'कायम']):
        answer = f"तपाईंको प्रश्न 'हक कायम' संबंधी छ। उपलब्ध {len(case_info)} केसहरूमा:\n\n"
    elif any(word in question_lower for word in ['फैसला', 'बदर', 'निर्णय']):
        answer = f"तपाईंको प्रश्न 'फैसला बदर' संबंधी छ। उपलब्ध {len(case_info)} केसहरूमा:\n\n"
    elif any(word in question_lower for word in ['अंश', 'दर्ता', 'बाँडफाँड']):
        answer = f"तपाईंको प्रश्न 'अंश दर्ता' संबंधी छ। उपलब्ध {len(case_info)} केसहरूमा:\n\n"
    else:
        answer = f"तपाईंको प्रश्नको सम्बन्धमा {len(case_info)} केसहरू भेटिएका छन्:\n\n"

    # List at most the first three cases: type, then subject.
    for case in case_info[:3]:
        answer += f"केस {case['number']}: {case['type']} - {case['subject']}\n"
        answer += f"मुख्य विषय: {case['subject']}\n\n"

    answer += "विस्तृत जानकारीका लागि माथिका केसहरूको लिङ्कहरू हेर्नुहोस्।"

    return answer
|
|
|
|
|
def semantic_search_ui(search_text: str):
    """Runs semantic search and returns formatted results. Also stores summarized context for RAG."""
    global _last_combined_context, _last_search_query

    print(f"DEBUG: Starting semantic search for: {search_text}")

    try:
        formatted, top_docs, combined_context = semantic_search(search_text, n_results=3)
        print(f"DEBUG: Retrieved {len(top_docs)} documents")

        # Build one compact, metadata-driven summary per retrieved document.
        summaries = []
        for idx, item in enumerate(top_docs, start=1):
            doc_text = item["document"]
            meta = item["metadata"]
            print(f"DEBUG: Processing document {idx}, length: {len(doc_text)}")

            # Start with whichever metadata fields are present, in fixed order.
            parts = []
            if meta.get('mudda_type'):
                parts.append(f"मुद्दाको किसिम: {meta['mudda_type']}")
            if meta.get('subject'):
                parts.append(f"विषय: {meta['subject']}")
            if meta.get('nibedak'):
                parts.append(f"निवेदक: {meta['nibedak'][:100]}...")
            if meta.get('vipakshi'):
                parts.append(f"विपक्षी: {meta['vipakshi'][:100]}...")

            # Strip list-literal artifacts and escaped newlines from the raw text.
            doc_clean = doc_text.replace('["', '').replace('"]', '').replace('\\n', ' ')

            # "Important" sentences: among the first five danda-separated
            # sentences, keep those long enough and mentioning a legal term.
            legal_terms = ['फैसला', 'ठहर', 'अदालत', 'मुद्दा', 'कानुन']
            stripped = (s.strip() for s in doc_clean.split('।')[:5])
            key_sentences = [
                s[:200] for s in stripped
                if len(s) > 20 and any(term in s.lower() for term in legal_terms)
            ]

            if key_sentences:
                parts.append("मुख्य बुँदाहरू: " + "। ".join(key_sentences[:2]) + "।")
            else:
                # No keyword hit — fall back to the document's opening text.
                clean_start = doc_clean[:300].strip()
                if clean_start:
                    parts.append(f"विवरण: {clean_start}...")

            manual_summary = " | ".join(parts)
            summaries.append(manual_summary)
            print(f"DEBUG: Created manual summary {idx}: {manual_summary[:100]}...")

        compact_context = build_compact_context(summaries)
        print(f"DEBUG: Built compact context, length: {len(compact_context)}")
        print(f"DEBUG: Context preview: {compact_context[:200]}...")

        # Cache for rag_answer() calls that do not supply their own query.
        _last_combined_context = compact_context
        _last_search_query = search_text

        return formatted, compact_context

    except Exception as e:
        error_msg = f"Error in semantic search: {e}"
        print(f"DEBUG: {error_msg}")
        return error_msg, ""
|
|
|
|
|
def rag_answer(question: str, search_text_for_context: str = "") -> str:
    """
    Answer the user's question using RAG:
    - If search_text_for_context provided, run semantic search for it and use its summaries.
    - Otherwise, use the last search context stored in memory (_last_combined_context).

    Returns the model's (or template fallback's) answer text with a timing
    footer appended, or a plain error message when no context is available.
    """
    global _last_combined_context, _last_search_query

    print(f"DEBUG: RAG answer called with question: {question[:50]}...")
    start_time = time.time()

    # Refresh the context from a fresh search when a query was supplied;
    # otherwise reuse the cached context from the last search.
    if search_text_for_context and search_text_for_context.strip():
        print("DEBUG: Refreshing context with new search")
        _, compact_context = semantic_search_ui(search_text_for_context)
        context = compact_context
    else:
        context = _last_combined_context
        print(f"DEBUG: Using cached context, length: {len(context)}")

    # Under ~50 chars of context there is nothing meaningful to answer from.
    if not context or len(context.strip()) < 50:
        return "No sufficient context available. Please run a semantic search first or provide a search query."

    print(f"DEBUG: Using context: {context[:300]}...")

    # flan-t5 was instruction-tuned on English prompts, so the instruction
    # is English; other models get a fully Nepali prompt.  Context is
    # truncated to 1500 chars to stay within the model's input budget.
    if "flan-t5" in ANSWER_MODEL.lower():
        prompt = f"Based on these Nepali legal case summaries, answer the question in Nepali:\n\nContext: {context[:1500]}\n\nQuestion: {question}\n\nProvide a detailed answer in Nepali:"
    else:
        prompt = (
            "तलका नेपाली अदालती मुद्दाका विवरणहरू प्रयोग गरेर प्रश्नको जवाफ नेपालीमा दिनुहोस्:\n\n"
            f"मुद्दाहरूको विवरण:\n{context[:1500]}\n\n"
            f"प्रश्न: {question}\n\n"
            "विस्तृत जवाफ:"
        )

    print(f"DEBUG: Generated prompt length: {len(prompt)}")
    print(f"DEBUG: Prompt preview: {prompt[:200]}...")

    try:
        print(f"DEBUG: Sending prompt to model (length: {len(prompt)})")
        print(f"DEBUG: First 300 chars of prompt: {prompt[:300]}")

        result = None

        # Strategy 1: greedy decoding with explicit length/padding settings.
        # NOTE(review): hasattr(tokenizer, 'pad_token_id') is effectively
        # always True, so the eos fallback branch likely never runs; if
        # pad_token_id is None it is passed through as None — confirm intent.
        try:
            result = answerer(
                prompt,
                max_length=300,
                min_length=10,
                do_sample=False,
                num_beams=1,
                pad_token_id=answerer.tokenizer.pad_token_id if hasattr(answerer.tokenizer, 'pad_token_id') else answerer.tokenizer.eos_token_id
            )
            print(f"DEBUG: Strategy 1 successful")
        except Exception as e1:
            print(f"DEBUG: Strategy 1 failed: {e1}")

            # Strategy 2: minimal-argument call as a simpler retry.
            try:
                result = answerer(prompt, max_length=200, do_sample=False)
                print(f"DEBUG: Strategy 2 successful")
            except Exception as e2:
                print(f"DEBUG: Strategy 2 failed: {e2}")

                # Both generation strategies failed: return the template
                # answer immediately with its own timing footer.
                print(f"DEBUG: Using template-based fallback")
                template_answer = create_template_answer(context, question)
                return template_answer + f"\n\n---\n(Generated using template fallback in {time.time() - start_time:.2f}s)"

        if result:
            out = result[0]["generated_text"].strip()
            print(f"DEBUG: Raw model output: '{out}'")
            print(f"DEBUG: Output length: {len(out)}")

            # Some models echo the prompt back; strip it from the output.
            if prompt in out:
                out = out.replace(prompt, "").strip()
                print(f"DEBUG: Cleaned output: '{out[:100]}...'")

            # Empty or degenerate generations fall back to the template.
            if not out or len(out) < 5 or out.lower() in ['none', 'n/a', '']:
                print(f"DEBUG: Output too short or meaningless, using template fallback")
                out = create_template_answer(context, question)

        else:
            print(f"DEBUG: No result from model, using template fallback")
            out = create_template_answer(context, question)

    except Exception as e:
        # Last-resort catch-all so the UI always gets some answer.
        print(f"DEBUG: All strategies failed: {e}")
        out = create_template_answer(context, question)

    elapsed = time.time() - start_time
    footer = f"\n\n---\n(Generated in {elapsed:.2f}s using summaries of top-3 cases.)"
    return out + footer
|
|
|
|
|
|
|
|
|
|
|
# ---- Gradio UI: three tabs (search, RAG question answering, model smoke test).
with gr.Blocks() as demo:
    gr.Markdown("# 📚 Semantic Search + RAG (auto-summarize top-3) — Nepali cases")
    gr.Markdown("**Debug Info**: Using models: " + SUMMARY_MODEL)

    # Tab 1: retrieval only — shows the formatted top-3 hits and the compact
    # context that semantic_search_ui caches for the RAG tab.
    with gr.Tab("🔍 Semantic Search"):
        search_input = gr.Textbox(
            label="Search for a case (use Nepali preferred)",
            placeholder="मुद्दाको संक्षेप वा कीवर्ड टाइप गर्नुहोस्..."
        )
        search_button = gr.Button("Search")
        search_results = gr.Markdown(label="Top 3 Similar Cases (formatted)")
        context_preview = gr.Textbox(
            label="Combined Summarized Context (for RAG)",
            interactive=False,
            max_lines=10
        )

        search_button.click(
            fn=semantic_search_ui,
            inputs=search_input,
            outputs=[search_results, context_preview]
        )

    # Tab 2: question answering over the cached (or freshly refreshed) context.
    with gr.Tab("🤖 Ask a Question (RAG)"):
        question_input = gr.Textbox(
            label="Your question (Nepali)",
            placeholder="यहाँ प्रश्न लेख्नुहोस्..."
        )
        optional_search_input = gr.Textbox(
            label="Optional: Search query to refresh context",
            placeholder="(Optional) provide a search query to refresh top-3 context"
        )
        ask_button = gr.Button("Get Answer")
        rag_output = gr.Markdown(label="LLM Answer (based on summarized top-3)")

        ask_button.click(
            fn=rag_answer,
            inputs=[question_input, optional_search_input],
            outputs=rag_output
        )

    # Tab 3: quick sanity check that the loaded answerer responds at all.
    with gr.Tab("🐛 Test Model"):
        test_input = gr.Textbox(label="Test input", placeholder="Enter test text...")
        test_button = gr.Button("Test Model")
        test_output = gr.Textbox(label="Model output")

        def test_model(text):
            """Smoke-test the answerer with one English and one Nepali prompt.

            Any exception is reported in the output box instead of raised.
            """
            if not text.strip():
                return "Please enter some text to test"

            try:
                # English instruction probe.
                simple_result = answerer(f"Translate to English: {text}", max_length=50, do_sample=False)
                result1 = simple_result[0]["generated_text"]

                # Nepali instruction probe.
                nepali_result = answerer(f"यसलाई नेपालीमा भन्नुहोस्: {text}", max_length=50, do_sample=False)
                result2 = nepali_result[0]["generated_text"]

                return f"English test: {result1}\n\nNepali test: {result2}\n\nModel is working!"

            except Exception as e:
                return f"Model test failed: {e}\n\nModel details:\n- Name: {ANSWER_MODEL}\n- Type: {type(answerer)}"

        test_button.click(fn=test_model, inputs=test_input, outputs=test_output)

    gr.Markdown("""
    **Notes**:
    - The system summarizes the top-3 semantic results and uses those summaries as context for the LLM
    - If you experience issues, try the Test Model tab first
    - Check the console logs for debugging information
    """)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server; debug=True enables verbose error reporting.
    demo.launch(debug=True)