# NOTE: the original paste carried Hugging Face Spaces page chrome here
# ("Spaces: Running") — it is not part of the program and is kept only as a comment.
| import gradio as gr | |
| import requests | |
| import numpy as np | |
| import time | |
| import json | |
| import os | |
# Import the utilities with proper error handling
# NOTE(review): on ImportError this only prints and continues, so any later
# call (e.g. get_vectorstore in chat_response) would fail with a NameError —
# consider re-raising or exiting here instead; confirm intended behavior.
try:
    from utils.encoding_input import encode_text
    from utils.retrieve_n_rerank import retrieve_and_rerank
    from utils.sentiment_analysis import get_sentiment
    from utils.coherence_bbscore import coherence_report
    from utils.loading_embeddings import get_vectorstore
    from utils.model_generation import build_messages
except ImportError as e:
    print(f"Import error: {e}")
    print("Make sure you're running from the correct directory and all dependencies are installed.")
# --- LLM API configuration ---
# SECURITY FIX: the previous code shipped a real-looking API key as the
# os.getenv() fallback literal, committing a secret to source control.
# Read the key from the environment only; if it is absent, leave it empty
# and let the API call fail with a clear authorization error.
API_KEY = os.getenv("API_KEY", "")
if not API_KEY:
    print("WARNING: API_KEY environment variable is not set; LLM requests will fail.")

# Model name served by the inference endpoint (see stream_llm_response).
MODEL = "llama3.3-70b-instruct"

# Global settings for sentiment and coherence analysis (toggled from the UI
# via update_sentiment_setting / update_coherence_setting).
ENABLE_SENTIMENT = True
ENABLE_COHERENCE = True
def chat_response(message, history):
    """
    Generate a streamed response for the chat interface.

    Args:
        message: Current user message (str).
        history: List of [user_message, bot_response] pairs from Gradio.

    Yields:
        Progressively longer partial response strings (the full text so far),
        so the UI can render the answer as it streams in.
    """
    try:
        # Initialize the vectorstore lazily, per request.
        vectorstore = get_vectorstore()

        # Retrieve candidate documents, then rerank with a cross-encoder.
        reranked_results = retrieve_and_rerank(
            query_text=message,
            vectorstore=vectorstore,
            k=50,  # number of initial documents to retrieve
            rerank_model="cross-encoder/ms-marco-MiniLM-L-6-v2",
            top_m=20,  # number of documents to return after reranking
            min_score=0.5,  # minimum score for reranked documents
            only_docs=False,  # return both documents and scores
        )

        if not reranked_results:
            # BUG FIX: this function is a generator (it yields below), so the
            # previous `return "..."` discarded this message as a StopIteration
            # value and the user saw nothing. It must be yielded instead.
            yield (
                "I'm sorry, I couldn't find any relevant information in the "
                "policy documents to answer your question. Could you try "
                "rephrasing your question or asking about a different topic?"
            )
            return

        top_docs = [doc for doc, score in reranked_results]

        # Optional analyses controlled by the module-level UI toggles.
        sentiment_rollup = get_sentiment(top_docs) if ENABLE_SENTIMENT else {}
        coherence_report_ = (
            coherence_report(reranked_results=top_docs, input_text=message)
            if ENABLE_COHERENCE
            else ""
        )

        # Build messages for the LLM, including conversation history.
        messages = build_messages_with_history(
            query=message,
            history=history,
            top_docs=top_docs,
            task_mode="verbatim_sentiment",
            sentiment_rollup=sentiment_rollup,
            coherence_report=coherence_report_,
        )

        # Stream the response, yielding the accumulated text each time so
        # the UI shows partial output.
        response = ""
        for chunk in stream_llm_response(messages):
            response += chunk
            yield response
    except Exception as e:
        # Top-level boundary: surface the error to the user instead of crashing
        # the Gradio event handler.
        yield f"I encountered an error while processing your request: {str(e)}"
def build_messages_with_history(query, history, top_docs, task_mode, sentiment_rollup, coherence_report):
    """Assemble the chat-completion message list for the LLM.

    Produces: system prompt, the most recent conversation turns, then the
    current query with the retrieved document context appended.

    Note: ``task_mode`` is accepted for interface compatibility but is not
    used in the prompt construction itself.
    """
    system_msg = (
        "You are a compliance-grade policy analyst assistant specializing in Kenya policy documents. "
        "Your job is to return precise, fact-grounded responses based on the provided policy documents. "
        "Avoid hallucinations. Base everything strictly on the content provided. "
        "Maintain conversation context from previous exchanges when relevant. "
        "If sentiment or coherence analysis is not available, do not mention it in the response."
    )

    messages = [{"role": "system", "content": system_msg}]

    # Replay only the last 4 exchanges: enough context without blowing the
    # token budget.
    recent_history = history[-4:] if len(history) > 4 else history
    for past_user, past_bot in recent_history:
        messages += [
            {"role": "user", "content": past_user},
            {"role": "assistant", "content": past_bot},
        ]

    # Render at most 10 retrieved documents as labelled context snippets
    # (more would risk exceeding the model's token limit).
    snippets = []
    for doc in top_docs[:10]:
        meta = getattr(doc, 'metadata', {})
        snippets.append(
            f"**Source: {meta.get('source', 'Unknown')} "
            f"(Page {meta.get('page', 'Unknown')})**\n"
            f"{doc.page_content}\n"
        )
    context_block = "\n\n".join(snippets)

    # The current user turn: fixed instructions, optional analysis sections,
    # then the document context.
    sections = [
        f"""
Query: {query}
Based on the following policy documents, please provide:
1) **Quoted Policy Excerpts**: Quote key policy content directly. Cite the source using filename and page.
2) **Analysis**: Explain the policy implications in clear terms.
"""
    ]
    if sentiment_rollup:
        sections.append(f"\n3) **Sentiment Summary**: {sentiment_rollup}")
    if coherence_report:
        sections.append(f"\n4) **Coherence Assessment**: {coherence_report}")
    sections.append(f"\n\nContext Sources:\n{context_block}")

    messages.append({"role": "user", "content": "".join(sections)})
    return messages
def stream_llm_response(messages):
    """Yield content deltas streamed from the chat-completions endpoint.

    Parses the server-sent-events stream ("data: {...}" lines), yielding each
    non-empty text delta. Errors are yielded as "[ERROR] ..." strings so the
    chat UI can display them instead of crashing.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    request_body = {
        "model": MODEL,
        "messages": messages,
        "temperature": 0.2,
        "stream": True,
        "max_tokens": 2000,
    }
    try:
        with requests.post(
            "https://inference.do-ai.run/v1/chat/completions",
            headers=headers,
            json=request_body,
            stream=True,
            timeout=30,
        ) as r:
            if r.status_code != 200:
                yield f"[ERROR] API returned status {r.status_code}: {r.text}"
                return
            for raw_line in r.iter_lines(decode_unicode=True):
                # Skip keep-alive blanks and the end-of-stream sentinel.
                if not raw_line or raw_line.strip() == "data: [DONE]":
                    continue
                # Strip the SSE "data: " prefix when present.
                if raw_line.startswith("data: "):
                    raw_line = raw_line[len("data: "):]
                try:
                    event = json.loads(raw_line)
                    delta = event.get("choices", [{}])[0].get("delta", {}).get("content", "")
                    if delta:
                        yield delta
                        time.sleep(0.01)  # Small delay for smooth streaming
                except json.JSONDecodeError:
                    # Non-JSON noise in the stream; ignore it.
                    continue
                except Exception as e:
                    # Best-effort: log and keep consuming the stream.
                    print(f"Streaming error: {e}")
                    continue
    except requests.exceptions.RequestException as e:
        yield f"[ERROR] Network error: {str(e)}"
    except Exception as e:
        yield f"[ERROR] Unexpected error: {str(e)}"
def update_sentiment_setting(enable):
    """Toggle the module-wide sentiment-analysis flag; return a status string."""
    global ENABLE_SENTIMENT
    ENABLE_SENTIMENT = enable
    state = 'enabled' if enable else 'disabled'
    return f"β Sentiment analysis {state}"
def update_coherence_setting(enable):
    """Toggle the module-wide coherence-analysis flag; return a status string."""
    global ENABLE_COHERENCE
    ENABLE_COHERENCE = enable
    state = 'enabled' if enable else 'disabled'
    return f"β Coherence analysis {state}"
# Create the chat interface.
# Layout: two columns — the chat panel (left, scale=3) and tips/status
# panels (right, scale=1). Event wiring at the bottom connects the widgets
# to the handler functions above.
with gr.Blocks(title="Kenya Policy Assistant - Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ποΈ Kenya Policy Assistant - Interactive Chat
Ask questions about Kenya's policies and have a conversation! I can help you understand policy documents with sentiment and coherence analysis.
""")
    with gr.Row():
        with gr.Column(scale=3):
            # Settings row at the top — the toggles write to the module-level
            # ENABLE_* flags read by chat_response().
            with gr.Row():
                sentiment_toggle = gr.Checkbox(
                    label="π Sentiment Analysis",
                    value=True,
                    info="Analyze tone and sentiment of policy documents"
                )
                coherence_toggle = gr.Checkbox(
                    label="π Coherence Analysis",
                    value=True,
                    info="Check coherence and consistency of retrieved documents"
                )
            # Main chat interface. History is the pairs format:
            # a list of [user_message, bot_response] lists.
            chatbot = gr.Chatbot(
                height=500,
                bubble_full_width=False,
                show_copy_button=True,
                show_share_button=True,
                avatar_images=("π€", "π€")
            )
            msg = gr.Textbox(
                placeholder="Ask me about Kenya's policies... (e.g., 'What are the renewable energy regulations?')",
                label="Your Question",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("π€ Send", variant="primary")
                clear_btn = gr.Button("ποΈ Clear Chat")
        with gr.Column(scale=1):
            # Static sidebar with usage tips and example questions.
            gr.Markdown("""
### π‘ Chat Tips
- Ask specific questions about Kenya policies
- Ask follow-up questions based on responses
- Reference previous answers: *"What does this mean?"*
- Request elaboration: *"Can you explain more?"*
### π Example Questions
- *"What are Kenya's renewable energy policies?"*
- *"Tell me about water management regulations"*
- *"What penalties exist for environmental violations?"*
- *"How does this relate to what you mentioned earlier?"*
### βοΈ Analysis Features
**Sentiment Analysis**: Understands the tone and intent of policy text
**Coherence Analysis**: Checks if retrieved documents are relevant and consistent
""")
            # Read-only status boxes updated by the toggle callbacks below.
            with gr.Accordion("π Analysis Status", open=False):
                sentiment_status = gr.Textbox(
                    value="β Sentiment analysis enabled",
                    label="Sentiment Status",
                    interactive=False
                )
                coherence_status = gr.Textbox(
                    value="β Coherence analysis enabled",
                    label="Coherence Status",
                    interactive=False
                )
    # Chat functionality
    def respond(message, history):
        # Streams partial bot responses into the chat history.
        # Yields (updated_history, cleared_textbox) pairs so Gradio renders
        # the answer incrementally; an empty/whitespace message is a no-op.
        if message.strip():
            bot_message = chat_response(message, history)
            history.append([message, ""])
            for partial_response in bot_message:
                # Overwrite the bot half of the latest turn with the
                # accumulated partial text.
                history[-1][1] = partial_response
                yield history, ""
        else:
            yield history, ""
    # Both the Send button and pressing Enter submit the message.
    submit_btn.click(respond, [msg, chatbot], [chatbot, msg])
    msg.submit(respond, [msg, chatbot], [chatbot, msg])
    # Clear resets both the chat history and the input box.
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
    # Update settings when toggles change
    sentiment_toggle.change(
        fn=update_sentiment_setting,
        inputs=[sentiment_toggle],
        outputs=[sentiment_status]
    )
    coherence_toggle.change(
        fn=update_coherence_setting,
        inputs=[coherence_toggle],
        outputs=[coherence_status]
    )
if __name__ == "__main__":
    print("π Starting Kenya Policy Assistant Chat...")
    # queue() is required for generator (streaming) handlers; max_size caps
    # the number of pending requests.
    # NOTE(review): share=True and debug=True look like development settings —
    # confirm they are intended for the deployed environment.
    demo.queue(max_size=20).launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860
    )