# app.py
import os
import re
import json
import openai
import streamlit as st
from typing import List, Dict, Any
from urllib.parse import quote_plus
from huggingface_hub import HfApi, InferenceClient
from pymongo import MongoClient
from PyPDF2 import PdfReader

st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")

from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import MongoDBAtlasVectorSearch
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
# =================== Secure Env via Hugging Face Secrets ===================
# Default to "" so quote_plus() doesn't crash when a secret is missing
user = quote_plus(os.getenv("MONGO_USERNAME", ""))
password = quote_plus(os.getenv("MONGO_PASSWORD", ""))
cluster = os.getenv("MONGO_CLUSTER", "")
db_name = os.getenv("MONGO_DB_NAME", "files")
collection_name = os.getenv("MONGO_COLLECTION", "files_collection")
index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index")
HF_TOKEN = os.getenv("HF_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()

if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

from openai import OpenAI
client = OpenAI(api_key=OPENAI_API_KEY)

# MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority"
MONGO_URI = f"mongodb+srv://{user}:{password}@{cluster}/{db_name}?retryWrites=true&w=majority&tls=true&tlsAllowInvalidCertificates=true"
# =================== Prompt ===================
grantbuddy_prompt = PromptTemplate.from_template(
    """You are Grant Buddy, a specialized language model fine-tuned with instruction-tuning and RLHF.
You help a nonprofit focused on social entrepreneurship, BIPOC empowerment, and edtech write clear, mission-aligned grant responses.

**Instructions:**
- Start with reasoning or context for your answer.
- Always align with the nonprofit's mission.
- Use structured formatting: headings, bullet points, numbered lists.
- Include impact data or examples if relevant.
- Do NOT repeat the same sentence or answer multiple times.
- If no answer exists in the context, say: "This information is not available in the current context."

CONTEXT:
{context}

QUESTION:
{question}
"""
)
# =================== Vector Search Setup ===================
@st.cache_resource
def init_embedding_model():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

@st.cache_resource
def init_vector_search() -> MongoDBAtlasVectorSearch:
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    st.write(f"🔗 Connecting to Hugging Face model: `{model_name}`")
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
    # Manual MongoClient with TLS settings
    user = quote_plus(os.getenv("MONGO_USERNAME", "").strip())
    password = quote_plus(os.getenv("MONGO_PASSWORD", "").strip())
    cluster = os.getenv("MONGO_CLUSTER", "").strip()
    db_name = os.getenv("MONGO_DB_NAME", "files").strip()
    collection_name = os.getenv("MONGO_COLLECTION", "files_collection").strip()
    index_name = os.getenv("MONGO_VECTOR_INDEX", "vector_index").strip()
    mongo_uri = f"mongodb+srv://{user}:{password}@{cluster}/?retryWrites=true&w=majority"
    try:
        client = MongoClient(mongo_uri, tls=True, tlsAllowInvalidCertificates=True, serverSelectionTimeoutMS=20000)
        db = client[db_name]
        collection = db[collection_name]
        st.success("✅ MongoClient connected successfully")
        return MongoDBAtlasVectorSearch(
            collection=collection,
            embedding=embedding_model,
            index_name=index_name,
        )
    except Exception as e:
        st.error("❌ Failed to connect to MongoDB Atlas")
        st.error(str(e))
        raise
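
# A sketch of the Atlas Search index this setup assumes (create it under
# Atlas -> Search using the name stored in MONGO_VECTOR_INDEX). The
# all-MiniLM-L6-v2 model emits 384-dimensional vectors:
#
#   {
#     "fields": [
#       {"type": "vector", "path": "embedding", "numDimensions": 384, "similarity": "cosine"}
#     ]
#   }
#
# "path" must match the field LangChain stores embeddings in ("embedding" by
# default); adjust it if your collection uses a different field name.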
# =================== Question/Headers Extraction ===================
# def extract_questions_and_headers(text: str) -> List[str]:
#     header_patterns = [
#         r'\d+\.\s+\*\*([^\*]+)\*\*',
#         r'\*\*([^*]+)\*\*',
#         r'^([A-Z][^a-z]*[A-Z])$',
#         r'^([A-Z][A-Za-z\s]{3,})$',
#         r'^[A-Z][A-Za-z\s]+:$'
#     ]
#     question_patterns = [
#         r'^.+\?$',
#         r'^\*?Please .+',
#         r'^How .+',
#         r'^What .+',
#         r'^Describe .+',
#     ]
#     combined_header_re = re.compile("|".join(header_patterns), re.MULTILINE)
#     combined_question_re = re.compile("|".join(question_patterns), re.MULTILINE)
#     headers = [match for group in combined_header_re.findall(text) for match in group if match]
#     questions = combined_question_re.findall(text)
#     return headers + questions

# def extract_with_llm(text: str) -> List[str]:
#     client = InferenceClient(api_key=HF_TOKEN.strip())
#     try:
#         response = client.chat.completions.create(
#             model="mistralai/Mistral-Nemo-Instruct-2407",  # or "HuggingFaceH4/zephyr-7b-beta"
#             messages=[
#                 {
#                     "role": "system",
#                     "content": "You are an assistant helping extract questions and headers from grant applications.",
#                 },
#                 {
#                     "role": "user",
#                     "content": (
#                         "Please extract all the grant application headers and questions from the following text. "
#                         "Include section titles, prompts, and any question-like content. Return them as a numbered list.\n\n"
#                         f"{text[:3000]}"
#                     ),
#                 },
#             ],
#             temperature=0.2,
#             max_tokens=512,
#         )
#         return [
#             line.strip("•-1234567890. ").strip()
#             for line in response.choices[0].message.content.strip().split("\n")
#             if line.strip()
#         ]
#     except Exception as e:
#         st.error("❌ LLM extraction failed")
#         st.error(str(e))
#         return []

# def extract_with_llm_local(text: str) -> List[str]:
#     prompt = (
#         "You are an assistant helping extract useful questions and section headers from a grant application.\n"
#         "Return only the important prompts as a numbered list.\n\n"
#         "TEXT:\n"
#         f"{text[:3000]}\n\n"
#         "PROMPTS:"
#     )
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
#     outputs = model.generate(
#         **inputs,
#         max_new_tokens=512,
#         temperature=0.3,
#         do_sample=False
#     )
#     raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     # Extract prompts from the numbered list in the output
#     lines = raw_output.split("\n")
#     prompts = []
#     for line in lines:
#         line = line.strip("•-1234567890. ").strip()
#         if len(line) > 10:
#             prompts.append(line)
#     return prompts

# def extract_with_llm_local(text: str) -> List[str]:
#     example_text = """TEXT:
# 1. Project Summary: Please describe the main goals of your project.
# 2. Contact Information: Address, phone, email.
# 3. What is the mission of your organization?
# 4. Who are the beneficiaries?
# 5. Budget Breakdown
# 6. Please describe how the funding will be used.
# 7. Website: www.example.org
# PROMPTS:
# 1. Project Summary
# 2. What is the mission of your organization?
# 3. Who are the beneficiaries?
# 4. Please describe how the funding will be used.
# """
#     prompt = (
#         "You are an assistant helping extract important grant application prompts and section headers.\n"
#         "Return only questions and meaningful section titles that require thoughtful answers.\n"
#         "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
#         f"{example_text}\n"
#         f"TEXT:\n{text[:3000]}\n\n"
#         "PROMPTS:"
#     )
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
#     outputs = model.generate(
#         **inputs,
#         max_new_tokens=512,
#         temperature=0.3,
#         do_sample=False
#     )
#     raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     # Clean and extract numbered or bulleted lines
#     lines = raw_output.split("\n")
#     prompts = []
#     for line in lines:
#         clean = line.strip("•-1234567890. ").strip()
#         if len(clean) > 10 and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"]):
#             prompts.append(clean)
#     return prompts
def extract_with_llm_local(text: str, use_openai: bool = False) -> List[str]:
    # Example context to prime the model
    example_text = """TEXT:
1. Project Summary: Please describe the main goals of your project.
2. Contact Information: Address, phone, email.
3. What is the mission of your organization?
4. Who are the beneficiaries?
5. Budget Breakdown
6. Please describe how the funding will be used.
7. Website: www.example.org
PROMPTS:
1. Project Summary
2. What is the mission of your organization?
3. Who are the beneficiaries?
4. Please describe how the funding will be used.
"""
    prompt = (
        "You are an assistant helping extract important grant application prompts and section headers.\n"
        "Return only questions and meaningful section titles that require thoughtful answers.\n"
        "Avoid metadata like phone numbers, dates, contact info, or websites.\n\n"
        f"{example_text}\n"
        f"TEXT:\n{text[:3000]}\n\n"
        "PROMPTS:"
    )
    if use_openai:
        if not openai.api_key:
            st.error("❌ OPENAI_API_KEY is not set.")
            return []  # keep the declared List[str] return type
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "You extract prompts and headers from grant text."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.2,
                max_tokens=500,
            )
            raw_output = response.choices[0].message.content
            st.markdown(f"🧮 Extract Tokens: Prompt = {response.usage.prompt_tokens}, "
                        f"Completion = {response.usage.completion_tokens}, Total = {response.usage.total_tokens}")
        except Exception as e:
            st.error(f"❌ OpenAI extraction failed: {e}")
            return []
    else:
        # tokenizer/model are the TinyLlama pair loaded at module scope below
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.3,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Clean and deduplicate prompts
    lines = raw_output.split("\n")
    prompts = []
    seen = set()
    for line in lines:
        clean = line.strip("•-1234567890. ").strip()
        if (
            len(clean) > 10
            and not any(bad in clean.lower() for bad in ["phone", "email", "address", "website"])
            and clean not in seen
        ):
            prompts.append(clean)
            seen.add(clean)
    return prompts
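
# Illustrative call (the uploaded_text name matches the variable used in the
# legacy UI code below; the output shown is hypothetical):
#   extract_with_llm_local(uploaded_text, use_openai=False)
#   -> ["Project Summary", "What is the mission of your organization?", ...]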
# def is_meaningful_prompt(text: str) -> bool:
#     too_short = len(text.strip()) < 10
#     banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
#     contains_bad_word = any(word in text.lower() for word in banned_keywords)
#     is_just_punctuation = all(c in ":.*- " for c in text.strip())
#     return not (too_short or contains_bad_word or is_just_punctuation)

# =================== Format Retrieved Chunks ===================
def format_docs(docs: List[Document]) -> str:
    return "\n\n".join(doc.page_content or doc.metadata.get("content", "") for doc in docs)
# =================== Generate Response from Hugging Face Model ===================
# def generate_response(input_dict: Dict[str, Any]) -> str:
#     client = InferenceClient(api_key=HF_TOKEN.strip())
#     prompt = grantbuddy_prompt.format(**input_dict)
#     try:
#         response = client.chat.completions.create(
#             model="HuggingFaceH4/zephyr-7b-beta",
#             messages=[
#                 {"role": "system", "content": prompt},
#                 {"role": "user", "content": input_dict["question"]},
#             ],
#             max_tokens=1000,
#             temperature=0.2,
#         )
#         return response.choices[0].message.content
#     except Exception as e:
#         st.error(f"❌ Error from model: {e}")
#         return "⚠️ Failed to generate response. Please check your model, HF token, or request format."
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

@st.cache_resource
def load_local_model():
    # Cache so the model loads once per session instead of on every rerun
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model

tokenizer, model = load_local_model()
def generate_response(input_dict, use_openai=False, max_tokens=700):
    prompt = grantbuddy_prompt.format(**input_dict)
    if use_openai:
        try:
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": input_dict["question"]},
                ],
                temperature=0.2,
                max_tokens=max_tokens,
            )
            answer = response.choices[0].message.content.strip()
            # Token logging
            prompt_tokens = response.usage.prompt_tokens
            completion_tokens = response.usage.completion_tokens
            total_tokens = response.usage.total_tokens
            return {
                "answer": answer,
                "tokens": {
                    "prompt": prompt_tokens,
                    "completion": completion_tokens,
                    "total": total_tokens
                }
            }
        except Exception as e:
            st.error(f"❌ OpenAI error: {e}")
            return {
                "answer": "⚠️ OpenAI request failed.",
                "tokens": {}
            }
    else:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Slice off the echoed prompt; assumes decode() round-trips the prompt text
        return {
            "answer": decoded[len(prompt):].strip(),
            "tokens": {}
        }
# =================== RAG Chain ===================
def get_rag_chain(retriever, use_openai=False, max_tokens=700):
    def merge_contexts(inputs):
        # Use pre-retrieved chunks when provided; otherwise fall back to the retriever
        retrieved_chunks = format_docs(inputs["context_docs"]) if "context_docs" in inputs \
            else format_docs(retriever.invoke(inputs["question"]))
        combined = "\n\n".join(filter(None, [
            inputs.get("manual_context", ""),
            retrieved_chunks
        ]))
        return {
            "context": combined,
            "question": inputs["question"]
        }
    return RunnableLambda(merge_contexts) | RunnableLambda(
        lambda input_dict: generate_response(input_dict, use_openai=use_openai, max_tokens=max_tokens)
    )
def rerank_with_topics(chunks, topics, alpha=0.2):
    """
    Boost chunks whose tagged topics overlap the requested topics.
    Chunks carry no similarity scores here, so we re-sort by original
    rank order minus a bonus per topic match (lower is better).
    """
    topics_lower = set(t.lower() for t in topics)
    def score(chunk, rank):
        meta = chunk.metadata or {}
        # Topics may live at the top level or under the nested "metadata" dict
        raw_topics = meta.get("topics") or meta.get("metadata", {}).get("topics", [])
        chunk_topics = [t.lower() for t in raw_topics]
        topic_matches = len(topics_lower.intersection(chunk_topics))
        return rank - alpha * topic_matches
    reranked = sorted(
        enumerate(chunks),
        key=lambda x: score(x[1], x[0])  # x[0] is rank, x[1] is chunk
    )
    return [chunk for _, chunk in reranked]
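
# Worked example: with alpha=0.6, a chunk at rank 3 that matches two requested
# topics scores 3 - 0.6*2 = 1.8 and overtakes a non-matching rank-2 chunk
# (score 2.0). With the default alpha=0.2, a chunk needs more than five topic
# matches to climb even one rank position.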
# =================== Streamlit UI ===================
def gate_ui():
    # Get password from secrets (optional env fallback)
    APP_PASSWORD = st.secrets.get("APP_PASSWORD", os.getenv("APP_PASSWORD", "")).strip()
    # Persist auth in session
    if "authed" not in st.session_state:
        st.session_state.authed = False
    if st.session_state.authed:
        return True
    st.title("🔐 Grant Buddy Login")
    pwd = st.text_input("Enter password", type="password")
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("Login"):
            if APP_PASSWORD and pwd == APP_PASSWORD:
                st.session_state.authed = True
                try:
                    st.rerun()
                except AttributeError:
                    st.experimental_rerun()  # older Streamlit versions
            else:
                st.error("Incorrect password.")
    with col2:
        if st.button("Forgot?"):
            st.info("Contact the admin to reset the APP_PASSWORD secret.")
    return False
# ===== RETRIEVAL SETTINGS SIDEBAR =====
def retrieval_settings():
    st.sidebar.header("Retrieval Settings")
    k_value = st.sidebar.slider("How many chunks to retrieve (k)", 5, 40, 10, 1)
    # min_score = st.sidebar.slider("Minimum relevance score", 0.0, 1.0, 0.75, 0.01)
    topics_str = st.sidebar.text_input("Optional: Focus on specific topics (comma-separated)", "")
    topic_score = st.sidebar.slider("Topic relevance score", 0.0, 1.0, 0.0, 0.01)
    st.sidebar.header("Generation Settings")
    max_tokens = st.sidebar.number_input("Max tokens in response", 100, 1500, 700, 50)
    use_openai = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
    topics = [t.strip() for t in topics_str.split(",") if t.strip()]
    return {
        "k": k_value,
        "topics": topics,
        "topic_score": topic_score,
        "max_tokens": max_tokens,
        "use_openai": use_openai
    }
def parse_topics_field(val):
    """Parse a topics field into a flat list of topics, splitting on commas and underscores."""
    if not val:
        return []
    # If it's already a list, normalize each entry
    if isinstance(val, list):
        topics = [str(t).strip() for t in val if str(t).strip()]
    else:
        # Split first by commas
        topics = [t.strip() for t in str(val).split(",") if t.strip()]
    # Now split each topic by underscores into subtopics
    topics = [sub for topic in topics for sub in topic.split("_") if sub.strip()]
    return topics
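
# Worked example (made-up input): parse_topics_field("stem_education, equity")
# first splits on commas -> ["stem_education", "equity"], then on underscores
# -> ["stem", "education", "equity"].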
def resolve_item_settings(item, defaults):
    """
    Merge per-item overrides with UI defaults.
    item: dict from JSON
    defaults: dict from sidebar (k, max_tokens, use_openai, topics, topic_score)
    """
    return {
        "use_openai": bool(item.get("use_openai", defaults["use_openai"])),
        "k": int(item.get("k", defaults["k"])),
        "max_tokens": int(item.get("max_tokens", defaults["max_tokens"])),
        # Rerank controls (note: the JSON key is "topic_weight", which maps to
        # the sidebar's "topic_score" default):
        "topics": parse_topics_field(item.get("topics", defaults.get("topics", []))),
        "topic_score": float(item.get("topic_weight", defaults.get("topic_score", 0.0))),
        "optional_context": item.get("optional_context", defaults.get("optional_context", "")),
    }
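
# Illustrative batch config (field names match resolve_item_settings above;
# every per-item key is optional and falls back to the sidebar defaults):
#
#   {
#     "queries": [
#       {
#         "query": "Describe your organization's mission.",
#         "k": 8,
#         "max_tokens": 500,
#         "use_openai": true,
#         "topics": "mission, equity",
#         "topic_weight": 0.3,
#         "optional_context": "Focus on the 2024 pilot program."
#       },
#       { "query": "Who are the beneficiaries?" }
#     ]
#   }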
def run_query(query: str,
              manual_context: str,
              vectorstore,
              use_openai: bool,
              k: int = 10,
              topic_list: list[str] | None = None,
              topic_alpha: float = 0.2,
              max_tokens: int = 700):
    # Safety clamps
    k = max(1, min(int(k), 40))
    topic_alpha = max(0.0, min(float(topic_alpha), 1.0))
    # 1) Overfetch for quality
    pre_k = max(1, k * 4)
    docs_scores = vectorstore.similarity_search_with_score(query, k=pre_k)
    # 2) Soft filter + fallback
    min_score = 0.75
    filtered = [(d, s) for d, s in docs_scores if s >= min_score]
    if len(filtered) < k:
        filtered = docs_scores[:k]  # fall back to top-k regardless of score
    docs = [d for d, _ in filtered]
    # 3) Topic re-rank AFTER filtering
    if topic_list and topic_alpha > 0:
        docs = rerank_with_topics(docs, topic_list, alpha=topic_alpha)
    # 4) Dedupe + trim
    seen, final = set(), []
    for d in docs:
        cid = (d.metadata or {}).get("chunk_id") or id(d)
        if cid in seen:
            continue
        seen.add(cid)
        final.append(d)
    docs = final[:k]
    # 5) Pass to the RAG chain; retriever=None is safe here because we always
    #    supply context_docs, so merge_contexts never calls the retriever
    rag_chain = get_rag_chain(retriever=None, use_openai=use_openai, max_tokens=max_tokens)
    combined_manual = (manual_context or "").strip()
    out = rag_chain.invoke({
        "question": query,
        "manual_context": combined_manual,
        "context_docs": docs
    })
    return {"answer": out.get("answer", ""), "tokens": out.get("tokens", {}), "docs": docs}
def show_chunks(docs):
    with st.expander("📄 Retrieved Chunks", expanded=False):
        for d in docs:
            meta_outer = d.metadata if isinstance(d.metadata, dict) else {}
            inner = meta_outer.get("metadata", {}) if isinstance(meta_outer.get("metadata", {}), dict) else {}
            title = inner.get("title", "unknown")
            chunk_id = meta_outer.get("chunk_id", "unknown")
            st.markdown(f"**Chunk ID:** {chunk_id} | **Title:** {title}")
            st.markdown((d.page_content or "")[:700] + "…")
            st.markdown("---")
# ===== MAIN =====
def main():
    if not gate_ui():
        return
    st.title("🤖 Grant Buddy – Manual / JSON Mode")
    settings = retrieval_settings()
    manual_context = st.text_area("📝 Optional: Global context for this run (mission, RFP notes, etc.)", height=150)
    vectorstore = init_vector_search()
    tab_manual, tab_batch = st.tabs(["✍️ Manual Mode", "🧩 JSON Batch Mode"])
    # ---- Manual Mode ----
    with tab_manual:
        st.subheader("Manual Query")
        uploaded_file = st.file_uploader("Upload PDF/TXT (optional)", type=["pdf", "txt"])
        uploaded_text = ""
        if uploaded_file:
            if uploaded_file.name.endswith(".pdf"):
                reader = PdfReader(uploaded_file)
                uploaded_text = "\n".join(p.extract_text() for p in reader.pages if p.extract_text())
            else:
                uploaded_text = uploaded_file.read().decode("utf-8")
        # Combine uploaded text with the global manual context
        combined_manual_context = "\n\n".join(
            s for s in [manual_context.strip(), uploaded_text.strip()] if s
        )
        query = st.text_input("Enter your question")
        if st.button("Run Manual Query"):
            if not query:
                st.warning("Please enter a question.")
            else:
                # Show the current settings being used
                st.write("### Retrieval/Generation Settings Used:")
                st.json(settings)
                result = run_query(
                    query=query,
                    manual_context=combined_manual_context,
                    vectorstore=vectorstore,
                    use_openai=settings["use_openai"],
                    k=settings["k"],
                    topic_list=settings["topics"],
                    topic_alpha=settings["topic_score"],
                    max_tokens=settings["max_tokens"]
                )
                st.markdown("### 💬 Answer")
                st.write(result["answer"])
                if result["tokens"]:
                    t = result["tokens"]
                    st.caption(f"Tokens – prompt: {t.get('prompt')}, completion: {t.get('completion')}, total: {t.get('total')}")
                show_chunks(result["docs"])
    # ---- JSON Batch Mode ----
    with tab_batch:
        st.subheader("Batch from JSON")
        st.link_button("Grant JSON Builder", "https://chatgpt.com/g/g-689b64bc10e88191bca964eea6b438a6-grant-json-builder")
        json_file = st.file_uploader("Upload JSON config", type=["json"])
        if json_file:
            cfg = json.load(json_file)
            st.json(cfg)
            if st.button("Run Batch"):
                results = []
                # Sidebar defaults
                defaults = {
                    "k": settings["k"],
                    "max_tokens": settings["max_tokens"],
                    "use_openai": settings["use_openai"],
                    "topics": settings["topics"],
                    "topic_score": settings["topic_score"],
                    "optional_context": manual_context,
                }
                queries = cfg.get("queries", [])
                if not queries:
                    st.warning("No 'queries' found in JSON.")
                for i, item in enumerate(queries, start=1):
                    q = (item.get("query") or "").strip()
                    if not q:
                        st.warning(f"Item {i} missing 'query'; skipping.")
                        continue
                    item_settings = resolve_item_settings(item, defaults)
                    result = run_query(
                        query=q,
                        manual_context=item_settings["optional_context"] or manual_context,  # global-context fallback
                        vectorstore=vectorstore,
                        use_openai=item_settings["use_openai"],
                        k=item_settings["k"],
                        topic_list=item_settings["topics"],
                        topic_alpha=item_settings["topic_score"],
                        max_tokens=item_settings["max_tokens"],
                    )
                    st.markdown(f"## 🧩 Query {i}")
                    st.markdown(f"**Prompt:** {q}")
                    st.caption(
                        f"Settings – use_openai={item_settings['use_openai']}, "
                        f"k={item_settings['k']}, max_tokens={item_settings['max_tokens']}, "
                        f"topic_weight={item_settings['topic_score']}, topics={item_settings['topics']}"
                    )
                    st.markdown(result["answer"])
                    if result["tokens"]:
                        t = result["tokens"]
                        st.caption(f"Tokens – prompt: {t.get('prompt')}, completion: {t.get('completion')}, total: {t.get('total')}")
                    show_chunks(result["docs"])
                    results.append({
                        "query": q,
                        "settings": item_settings,
                        "answer": result["answer"],
                        "tokens": result["tokens"]
                    })
                st.download_button(
                    "💾 Download results JSON",
                    data=json.dumps({"results": results}, indent=2),
                    file_name="grantbuddy_results.json",
                    mime="application/json"
                )

if __name__ == "__main__":
    main()
# Legacy main() kept for reference: extracted prompts from an uploaded file
# and answered the selected ones via checkboxes.
# def main():
#     if not gate_ui():
#         return
#     # st.set_page_config(page_title="Grant Buddy RAG", page_icon="🤖")
#     st.title("🤖 Grant Buddy: Grant-Writing Assistant")
#     USE_OPENAI = st.sidebar.checkbox("Use OpenAI (Costs Tokens)", value=False)
#     st.sidebar.markdown("### Retrieval Settings")
#     k_value = st.sidebar.slider("How many chunks to retrieve (k)", min_value=5, max_value=40, step=5, value=10)
#     score_threshold = st.sidebar.slider("Minimum relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.75)
#     topic_input = st.sidebar.text_input("Optional: Focus on specific topics (comma-separated)")
#     topics = [t.strip() for t in topic_input.split(",") if t.strip()]
#     topic_weight = st.sidebar.slider("Topic relevance score", min_value=0.0, max_value=1.0, step=0.05, value=0.2)
#     st.sidebar.markdown("### Generation Settings")
#     max_tokens = st.sidebar.number_input("Max tokens in response", min_value=100, max_value=1500, value=700, step=50)
#     if "generated_queries" not in st.session_state:
#         st.session_state.generated_queries = {}
#     manual_context = st.text_area("📝 Optional: Add your own context (e.g., mission, goals)", height=150)
#     # retriever = init_vector_search().as_retriever(search_kwargs={"k": k_value, "score_threshold": score_threshold})
#     retriever = init_vector_search().as_retriever()
#     vectorstore = init_vector_search()
#     # pre_k = k_value * 4  # Retrieve more chunks first
#     # context_docs = retriever.get_relevant_documents(query, k=pre_k)
#     # if topics:
#     #     context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
#     # context_docs = context_docs[:k_value]  # Final top-k used in RAG
#     rag_chain = get_rag_chain(retriever, use_openai=USE_OPENAI, max_tokens=max_tokens)
#     uploaded_file = st.file_uploader("Upload PDF or TXT for extra context (optional)", type=["pdf", "txt"])
#     uploaded_text = ""
#     if uploaded_file:
#         with st.spinner("📄 Processing uploaded file..."):
#             if uploaded_file.name.endswith(".pdf"):
#                 reader = PdfReader(uploaded_file)
#                 uploaded_text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
#             elif uploaded_file.name.endswith(".txt"):
#                 uploaded_text = uploaded_file.read().decode("utf-8")
#         # Extract questions and headers using LLMs
#         questions = extract_with_llm_local(uploaded_text, use_openai=USE_OPENAI)
#         # Filter out irrelevant text
#         def is_meaningful_prompt(text: str) -> bool:
#             too_short = len(text.strip()) < 10
#             banned_keywords = ["phone", "email", "fax", "address", "date", "contact", "website"]
#             contains_bad_word = any(word in text.lower() for word in banned_keywords)
#             is_just_punctuation = all(c in ":.*- " for c in text.strip())
#             return not (too_short or contains_bad_word or is_just_punctuation)
#         filtered_questions = [q for q in questions if is_meaningful_prompt(q)]
#         with st.form("question_selection_form"):
#             st.subheader("Choose prompts to answer:")
#             selected_questions = []
#             for i, q in enumerate(filtered_questions):
#                 if st.checkbox(q, key=f"q_{i}", value=True):
#                     selected_questions.append(q)
#             submit_button = st.form_submit_button("Submit")
#     # Multi-select questions
#     if 'submit_button' in locals() and submit_button:
#         if selected_questions:
#             with st.spinner("💡 Generating answers..."):
#                 answers = []
#                 for q in selected_questions:
#                     combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
#                     pre_k = k_value * 4
#                     context_docs = retriever.get_relevant_documents(q, k=pre_k)
#                     if topics:
#                         context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
#                     context_docs = context_docs[:k_value]
#                     # full_query = f"{q}\n\nAdditional context:\n{uploaded_text}"
#                     if q in st.session_state.generated_queries:
#                         response = st.session_state.generated_queries[q]
#                     else:
#                         response = rag_chain.invoke({
#                             "question": q,
#                             "manual_context": combined_context,
#                             "context_docs": context_docs
#                         })
#                         st.session_state.generated_queries[q] = response
#                     answers.append({"question": q, "answer": response})
#                 for item in answers:
#                     st.markdown(f"### ❓ {item['question']}")
#                     st.markdown(f"💬 {item['answer']['answer']}")
#                     tokens = item['answer'].get("tokens", {})
#                     if tokens:
#                         st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
#                                     f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
#         else:
#             st.info("No prompts selected for answering.")
#     # ✍️ Manual single-question input
#     query = st.text_input("Ask a grant-related question")
#     if st.button("Submit"):
#         if not query:
#             st.warning("Please enter a question.")
#             return
#         # full_query = f"{query}\n\nAdditional context:\n{uploaded_text}" if uploaded_text else query
#         pre_k = k_value * 4
#         context_docs = retriever.get_relevant_documents(query, k=pre_k)
#         if topics:
#             context_docs = rerank_with_topics(context_docs, topics, alpha=topic_weight)
#         context_docs = context_docs[:k_value]
#         combined_context = "\n\n".join(filter(None, [manual_context.strip(), uploaded_text.strip()]))
#         with st.spinner("🤔 Thinking..."):
#             # response = rag_chain.invoke(full_query)
#             response = rag_chain.invoke({"question": query, "manual_context": combined_context, "context_docs": context_docs})
#         st.text_area("Grant Buddy says:", value=response["answer"], height=250, disabled=True)
#         tokens = response.get("tokens", {})
#         if tokens:
#             st.markdown(f"🧮 **Token Usage:** Prompt = {tokens.get('prompt')}, "
#                         f"Completion = {tokens.get('completion')}, Total = {tokens.get('total')}")
#         with st.expander("📄 Retrieved Chunks"):
#             # context_docs = retriever.get_relevant_documents(query)
#             for doc in context_docs:
#                 # st.json(doc.metadata)
#                 st.markdown(f"**Chunk ID:** {doc.metadata.get('chunk_id', 'unknown')} | **Title:** {doc.metadata['metadata'].get('title', 'unknown')}")
#                 st.markdown(doc.page_content[:700] + "...")
#                 if topics:
#                     matched_topics = set(doc.metadata['metadata'].get('topics', [])).intersection(topics)
#                     st.markdown(f"**Matched Topics:** {', '.join(matched_topics)}")
#                 st.markdown("---")
#
# if __name__ == "__main__":
#     main()