| """SLM with RAG for financial statements""" | |
| # Importing the dependencies | |
| import logging | |
| import os | |
| import subprocess | |
| import time | |
| import re | |
| import pickle | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import spacy | |
| import pdfplumber | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import gradio as gr | |
| import faiss | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from data_filters import ( | |
| restricted_patterns, | |
| restricted_topics, | |
| FINANCIAL_DATA_PATTERNS, | |
| FINANCIAL_ENTITY_LABELS, | |
| GENERAL_KNOWLEDGE_PATTERNS, | |
| sensitive_terms, | |
| EXPLANATORY_PATTERNS, | |
| FINANCIAL_TERMS, | |
| ) | |

# Initialize logger
logging.basicConfig(
    # filename="app.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger()

os.makedirs("data", exist_ok=True)

# SLM: the Microsoft Phi-2 model is loaded
# It has higher memory and compute requirements than TinyLlama and Falcon,
# but it gives the best results among the three
DEVICE = "cpu"  # set to "cuda" to run on a GPU
# DEVICE = "cuda"
# MODEL_NAME = "TinyLlama/TinyLlama_v1.1"
# MODEL_NAME = "tiiuae/falcon-rw-1b"
MODEL_NAME = "microsoft/phi-2"
# MODEL_NAME = "google/gemma-3-1b-pt"

# Load the tokenizer for Phi-2
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
MAX_TOKENS = tokenizer.model_max_length
CONTEXT_MULTIPLIER = 0.7
# MAX_CONTEXT_TOKENS limits the retrieved chunks during querying
# to leave headroom for the rest of the prompt and the query
MAX_CONTEXT_TOKENS = int(MAX_TOKENS * CONTEXT_MULTIPLIER)
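
# Illustrative budget (assuming Phi-2's 2048-token context window):
#     MAX_TOKENS = 2048
#     MAX_CONTEXT_TOKENS = int(2048 * 0.7) = 1433
# leaving roughly 600 tokens for the prompt instructions, the query,
# and the generated answer.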

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Since the model is to be hosted on a CPU instance, we use float32
# For GPU, we can use float16 or bfloat16
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, trust_remote_code=True
).to(DEVICE)
model.eval()
# model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
logger.info("Model loaded successfully.")

# Load Sentence Transformer for embeddings and Cross-Encoder for re-ranking
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Load the spaCy English model for Named Entity Recognition (mainly for the guardrail)
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run(["python", "-m", "spacy", "download", "en_core_web_sm"])
    nlp = spacy.load("en_core_web_sm")

# Extract the year from the uploaded file's name, if present
def extract_year_from_filename(filename):
    """Extract Year from Filename"""
    match = re.search(r"(\d{4})-(\d{4})", filename)
    if match:
        return match.group(1)
    match = re.search(r"(\d{4})", filename)
    return match.group(1) if match else "Unknown"
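
# Illustrative usage (hypothetical filenames):
#     extract_year_from_filename("annual-report-2022-2023.pdf")  -> "2022"
#     extract_year_from_filename("balance_sheet_2021.csv")       -> "2021"
#     extract_year_from_filename("statements.csv")               -> "Unknown"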

# Use pdfplumber to extract the tables from the uploaded file
# Add the year column for context and create a DataFrame
def extract_tables_from_pdf(pdf_path):
    """Extract tables from PDF into a DataFrame"""
    all_tables = []
    report_year = extract_year_from_filename(pdf_path)
    with pdfplumber.open(pdf_path) as pdf:
        for page_num, page in enumerate(pdf.pages, start=1):
            tables = page.extract_tables()
            for table in tables:
                df = pd.DataFrame(table)
                df["year"] = report_year
                all_tables.append(df)
    return pd.concat(all_tables, ignore_index=True) if all_tables else pd.DataFrame()

# Load the CSV files directly into a DataFrame using pandas
def load_csv(file_path):
    """Loads a CSV file into a DataFrame"""
    try:
        df = pd.read_csv(file_path)
        df["year"] = extract_year_from_filename(file_path)
        return df
    except Exception as e:
        print(f"Error loading CSV: {e}")
        return None

# Preprocess the DataFrame - replace null values and build text rows suitable for chunking
def clean_dataframe_text(df):
    """Clean and format PDF/CSV data"""
    df.fillna("", inplace=True)
    text_data = []
    for _, row in df.iterrows():
        parts = []
        if "year" in df.columns:
            parts.append(f"Year: {row['year']}")
        parts.extend([str(val).strip() for val in row if str(val).strip()])
        text_data.append(", ".join(parts))
    df["text"] = text_data
    return df[["text"]].replace("", np.nan).dropna()

# Chunk the text for retrieval
# Chunk sizes of 256, 512, 1024 and 2048 were tried; 512 worked best for financial RAG
def chunk_text(text, chunk_size=512):
    """Apply Chunking on the text"""
    words = text.split()
    chunks, temp_chunk = [], []
    for word in words:
        if sum(len(w) for w in temp_chunk) + len(temp_chunk) + len(word) <= chunk_size:
            temp_chunk.append(word)
        else:
            chunks.append(" ".join(temp_chunk))
            temp_chunk = [word]
    if temp_chunk:
        chunks.append(" ".join(temp_chunk))
    return chunks
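
# The size check counts characters rather than tokens:
#     sum(len(w) for w in temp_chunk)  -> characters already in the chunk
#     + len(temp_chunk)                -> one space per existing word
#     + len(word)                      -> the candidate word
# Illustrative run (hypothetical input, chunk_size=10):
#     chunk_text("net profit rose sharply", chunk_size=10)
#     -> ["net profit", "rose", "sharply"]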

# Uses regex to detect financial data so that such chunks are kept intact during merging
def is_financial_text(text):
    """Detects financial data"""
    return bool(
        re.search(
            FINANCIAL_DATA_PATTERNS,
            text,
            re.IGNORECASE,
        )
    )

# Uses the sentence transformer "all-MiniLM-L6-v2" to embed text chunks
# Stores the embeddings in a FAISS vector index for similarity search
# BM25 is implemented alongside FAISS to improve retrieval
# A FAISS cosine-similarity index is used, and only highly similar text chunks (>85%) are merged
def merge_similar_chunks(chunks, similarity_threshold=0.85):
    """Merge similar chunks while preserving financial data structure"""
    if not chunks:
        return []
    # Encode chunks into embeddings
    embeddings = np.array(
        embed_model.encode(chunks, normalize_embeddings=True), dtype="float32"
    )
    # FAISS cosine-similarity index (inner product on unit-normalized vectors)
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    # Get the top-2 most similar chunks (the first hit is the chunk itself)
    _, indices = index.search(embeddings, 2)
    merged_chunks = {}
    for i, idx in enumerate(indices[:, 1]):
        if i in merged_chunks or idx in merged_chunks:
            continue
        sim_score = np.dot(embeddings[i], embeddings[idx])
        # Ensure financial data isn't incorrectly merged
        if is_financial_text(chunks[i]) or is_financial_text(chunks[idx]):
            merged_chunks[i] = chunks[i]
            merged_chunks[idx] = chunks[idx]
            continue
        # Merge only if similarity is high and the chunks are adjacent
        if sim_score > similarity_threshold and abs(i - idx) == 1:
            merged_chunks[i] = chunks[i] + " " + chunks[idx]
            merged_chunks[idx] = merged_chunks[i]
        else:
            merged_chunks[i] = chunks[i]
    return list(set(merged_chunks.values()))
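
# With normalize_embeddings=True every vector has unit length, so the
# inner-product score from IndexFlatIP equals cosine similarity:
#     cos(a, b) = (a . b) / (|a| * |b|) = a . b   when |a| = |b| = 1
# e.g. two adjacent narrative chunks scoring 0.87 exceed the 0.85 threshold and
# are concatenated, while chunks flagged as financial are always kept as-is.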

# Handler for the file-upload button in the UI
# Processes the uploaded files and generates the embeddings
# The FAISS index and the tokenized chunks are saved for retrieval
def process_files(files, chunk_size=512):
    """Process uploaded files and generate embeddings"""
    if not files:
        logger.warning("No files uploaded!")
        return "Please upload at least one PDF or CSV file."
    pdf_paths = [file.name for file in files if file.name.endswith(".pdf")]
    csv_paths = [file.name for file in files if file.name.endswith(".csv")]
    logger.info(f"Processing {len(pdf_paths)} PDFs and {len(csv_paths)} CSVs")
    df_list = []
    if pdf_paths:
        df_list.extend([extract_tables_from_pdf(pdf) for pdf in pdf_paths])
    for csv in csv_paths:
        df = load_csv(csv)
        if df is not None:
            df_list.append(df)
    if not df_list:
        logger.warning("No valid data found in the uploaded files")
        return "No valid data found in the uploaded files"
    df = pd.concat(df_list, ignore_index=True)
    df.dropna(how="all", inplace=True)
    logger.info("Data extracted from the files")
    df_cleaned = clean_dataframe_text(df)
    df_cleaned["chunks"] = df_cleaned["text"].apply(lambda x: chunk_text(x, chunk_size))
    df_chunks = df_cleaned.explode("chunks").reset_index(drop=True)
    merged_chunks = merge_similar_chunks(df_chunks["chunks"].tolist())
    chunk_texts = merged_chunks
    # chunk_texts = df_chunks["chunks"].tolist()
    embeddings = np.array(
        embed_model.encode(chunk_texts, normalize_embeddings=True), dtype="float32"
    )
    # Save the FAISS index
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    faiss.write_index(index, "data/faiss_index.bin")
    logger.info("FAISS index created and saved.")
    # Save the BM25 data (tokenized chunks plus the raw chunk texts)
    tokenized_chunks = [text.lower().split() for text in chunk_texts]
    bm25_data = {"tokenized_chunks": tokenized_chunks, "chunk_texts": chunk_texts}
    with open("data/bm25_data.pkl", "wb") as f:
        pickle.dump(bm25_data, f)
    logger.info("BM25 data created and saved.")
    return "Files processed successfully! You can now query."

def contains_financial_entities(query):
    """Check if query contains financial entities"""
    doc = nlp(query)
    for ent in doc.ents:
        if ent.label_ in FINANCIAL_ENTITY_LABELS:
            return True
    return False

def contains_geographical_entities(query):
    """Check if the query contains geographical entities"""
    doc = nlp(query)
    return any(ent.label_ == "GPE" for ent in doc.ents)

def contains_financial_terms(query):
    """Check if the query contains financial terms"""
    return any(term in query.lower() for term in FINANCIAL_TERMS)

def is_general_knowledge_query(query):
    """Check if query contains general knowledge"""
    query_lower = query.lower()
    for pattern in GENERAL_KNOWLEDGE_PATTERNS:
        if re.search(pattern, query_lower):
            return True
    return False

def get_latest_available_year(retrieved_chunks):
    """Extracts the latest available year from retrieved financial data"""
    years = set()
    year_pattern = r"\b(20\d{2})\b"
    for chunk in retrieved_chunks:
        years.update(map(int, re.findall(year_pattern, chunk)))
    return max(years) if years else 2024

def is_irrelevant_query(query):
    """Check if the query is not finance related"""
    # The query is general knowledge and not finance-related
    if is_general_knowledge_query(query) and not contains_financial_terms(query):
        return True
    # The query contains only geographical terms without financial entities
    if contains_geographical_entities(query) and not contains_financial_entities(query):
        return True
    return False

# Input guardrail implementation
# NER + regex + a list of terms are used to filter irrelevant queries
# Regex is used to filter queries related to sensitive topics
# The spaCy model's Named Entity Recognition filters queries asking for personal details
# Cosine similarity between the embedded query and sensitive-topic vectors
# filters out queries violating confidentiality/security rules (additional check)
def is_query_allowed(query):
    """Checks if the query violates security or confidentiality rules"""
    if is_irrelevant_query(query):
        return False, "Query is not finance-related. Please ask a financial question."
    for pattern in restricted_patterns:
        if re.search(pattern, query.lower(), re.IGNORECASE):
            return False, "This query requests sensitive or confidential information."
    doc = nlp(query)
    # Check if there is a person entity whose subtree contains sensitive terms
    for ent in doc.ents:
        if ent.label_ == "PERSON":
            for token in ent.subtree:
                if token.text.lower() in sensitive_terms:
                    return (
                        False,
                        "Query contains personal salary information, which is restricted.",
                    )
    query_embedding = embed_model.encode(query, normalize_embeddings=True)
    topic_embeddings = embed_model.encode(
        list(restricted_topics), normalize_embeddings=True
    )
    # Check similarities between the restricted topics and the query
    similarities = np.dot(topic_embeddings, query_embedding)
    if np.max(similarities) > 0.85:
        return False, "This query requests sensitive or confidential information."
    return True, None
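
# Illustrative behaviour (hypothetical queries; the exact outcome depends on the
# patterns and topics defined in data_filters):
#     is_query_allowed("What was the operating margin in 2022?")
#         -> (True, None)
#     is_query_allowed("What is the CEO's home address?")
#         -> (False, "This query requests sensitive or confidential information.")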

# Boosts the scores for texts containing financial terms
# This is useful during re-ranking
def boost_score(text, base_score, boost_factor=1.2):
    """Boost scores if the text contains financial terms"""
    if any(term in text.lower() for term in FINANCIAL_TERMS):
        return base_score * boost_factor
    return base_score
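
# e.g. assuming "revenue" is listed in FINANCIAL_TERMS, a chunk mentioning it
# with a base weight of 0.7 is boosted to 0.7 * 1.2 = 0.84, nudging financially
# relevant chunks up the ranking.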

# FAISS embeddings are used to retrieve semantically similar chunks
# BM25 is used to retrieve relevant chunks based on keywords (TF-IDF style scoring)
# FAISS and BM25 complement each other: semantic matches plus important exact matches
# The retrieved chunks are merged and sorted using the lambda_faiss weight:
# if lambda_faiss is 0.6, FAISS chunks get a weight of 0.6 and BM25 chunks get 0.4
# The cross-encoder ms-marco-MiniLM-L6-v2 is used for scoring and re-ranking the chunks
def hybrid_retrieve(query, chunk_texts, index, bm25, top_k=5, lambda_faiss=0.7):
    """Hybrid Retrieval with FAISS, BM25, Cross-Encoder & Financial Term Boosting"""
    # FAISS retrieval
    query_embedding = np.array(
        [embed_model.encode(query, normalize_embeddings=True)], dtype="float32"
    )
    _, faiss_indices = index.search(query_embedding, top_k)
    faiss_results = [chunk_texts[idx] for idx in faiss_indices[0]]
    # BM25 retrieval
    tokenized_query = query.lower().split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    bm25_results = [chunk_texts[idx] for idx in bm25_top_indices]
    # Merge the FAISS & BM25 scores
    results = {}
    for entry in faiss_results:
        results[entry] = boost_score(entry, lambda_faiss)
    for entry in bm25_results:
        results[entry] = results.get(entry, 0) + boost_score(entry, (1 - lambda_faiss))
    # Rank the initial results
    retrieved_docs = sorted(results.items(), key=lambda x: x[1], reverse=True)
    retrieved_texts = [r[0] for r in retrieved_docs]
    # Cross-encoder re-ranking
    query_text_pairs = [[query, text] for text in retrieved_texts]
    scores = cross_encoder.predict(query_text_pairs)
    ranked_indices = np.argsort(scores)[::-1]
    # Return the top-ranked results
    final_results = [retrieved_texts[i] for i in ranked_indices[:top_k]]
    return final_results
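
# Worked example of the merge step (lambda_faiss = 0.7, no term boosting):
#     chunk A appears only in the FAISS results  -> score 0.7
#     chunk B appears only in the BM25 results   -> score 0.3
#     chunk C appears in both                    -> score 0.7 + 0.3 = 1.0
# Chunks containing financial terms have each contribution multiplied by 1.2
# before the cross-encoder produces the final ordering.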

def compute_entropy(logits):
    """Compute entropy from logits."""
    probs = torch.softmax(logits, dim=-1)
    log_probs = torch.log(probs + 1e-9)
    entropy = -(probs * log_probs).sum(dim=-1)
    return entropy.mean().item()
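
# Entropy H = -sum(p * log p) is 0 when all probability mass sits on one token
# (the model is certain) and reaches log(vocab_size) for a uniform distribution;
# for a vocabulary of roughly 50k tokens that upper bound is about log(50000) ≈ 10.8 nats.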

def contains_future_year(query, retrieved_chunks):
    """Detects if the query asks for future data beyond available reports"""
    latest_year = get_latest_available_year(retrieved_chunks)
    # Extract years from the query
    future_years = set(map(int, re.findall(r"\b(20\d{2})\b", query)))
    return any(year > latest_year for year in future_years)

def is_explanatory_query(query):
    """Checks if the query requires an explanation rather than factual data"""
    query_lower = query.lower()
    return any(re.search(pattern, query_lower) for pattern in EXPLANATORY_PATTERNS)

# A confidence score is computed using FAISS and BM25 ranking
# FAISS: the similarity score between the response and the retrieved chunks is normalized
# BM25: the BM25 scores for the combined query + response tokens are normalized
# model_conf_signal is the mean of the top-token probability and the (1 - entropy) score
# FAISS, BM25 and model_conf_signal are combined using a weighted sum
def compute_response_confidence(
    query,
    response,
    retrieved_chunks,
    bm25,
    model_conf_signal,
    lambda_faiss=0.6,
    lambda_conf=0.3,
    lambda_bm25=1.0,
    future_penalty=-0.3,
    explanation_penalty=-0.2,
):
    """Calculates a confidence score for the model response"""
    if not retrieved_chunks:
        return 0.0
    # Compute the FAISS similarity
    retrieved_embedding = embed_model.encode(
        " ".join(retrieved_chunks), normalize_embeddings=True
    )
    response_embedding = embed_model.encode(response, normalize_embeddings=True)
    faiss_score = np.dot(retrieved_embedding, response_embedding)
    # Normalize the FAISS score from [-1, 1] to [0, 1]
    normalized_faiss = (faiss_score + 1) / 2
    # Compute BM25 for the combined query + response
    tokenized_combined = (query + " " + response).lower().split()
    bm25_scores = bm25.get_scores(tokenized_combined)
    # Normalize the BM25 score
    if bm25_scores.size > 0:
        bm25_score = np.mean(bm25_scores)
        min_bm25, max_bm25 = np.min(bm25_scores), np.max(bm25_scores)
        normalized_bm25 = (
            (bm25_score - min_bm25) / (max_bm25 - min_bm25 + 1e-6)
            if min_bm25 != max_bm25
            else 0
        )
        normalized_bm25 = max(0, min(1, normalized_bm25))
    else:
        normalized_bm25 = 0.0
    # Penalize if the query asks about future years
    future_penalty_value = (
        future_penalty if contains_future_year(query, retrieved_chunks) else 0.0
    )
    # Penalize if the query is reasoning-based
    explanation_penalty_value = (
        explanation_penalty if is_explanatory_query(query) else 0.0
    )
    logger.info(
        f"FAISS score: {normalized_faiss}, BM25: {normalized_bm25}\n"
        f"Mean top-token + (1 - entropy) average: {model_conf_signal}\n"
        f"Future penalty: {future_penalty_value}, Reasoning penalty: {explanation_penalty_value}"
    )
    # Weighted sum of all the normalized scores
    confidence_score = (
        lambda_faiss * normalized_faiss
        + model_conf_signal * lambda_conf
        + lambda_bm25 * normalized_bm25
        + future_penalty_value
        + explanation_penalty_value
    )
    return round(min(100, max(0, confidence_score.item() * 100)), 2)
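
# Worked example with the default weights (hypothetical component scores):
#     normalized_faiss = 0.80, normalized_bm25 = 0.20, model_conf_signal = 0.60
#     confidence = 0.6 * 0.80 + 0.3 * 0.60 + 1.0 * 0.20 = 0.86  -> 86.0%
# A query about a year beyond the latest report subtracts 0.3 (-> 56.0%),
# and an explanatory "why/how" question subtracts a further 0.2.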

# UI handler for the query-model button
# Loads the saved FAISS index and the tokenized chunks for BM25
# Checks the query for any policy violation
# Retrieves similar texts using the RAG implementation
# Prompts the loaded SLM with the retrieved texts and computes a confidence score
def query_model(
    query,
    top_k=10,
    lambda_faiss=0.5,
    repetition_penalty=1.5,
    max_new_tokens=100,
    use_extraction=False,
):
    """Query function"""
    start_time = time.perf_counter()
    # Check if the FAISS and BM25 indexes exist
    if not os.path.exists("data/faiss_index.bin") or not os.path.exists(
        "data/bm25_data.pkl"
    ):
        logger.error("No index found! Prompting user to upload PDFs.")
        return (
            "Index files not found! Please upload PDFs first to generate embeddings.",
            "Error",
        )
    allowed, reason = is_query_allowed(query)
    if not allowed:
        logger.error(f"Query Rejected: {reason}")
        return f"Query Rejected: {reason}", "Warning"
    logger.info(
        f"Received query: {query} | Top-K: {top_k}, "
        f"Lambda: {lambda_faiss}, Tokens: {max_new_tokens}"
    )
    # Load the FAISS & BM25 indexes
    index = faiss.read_index("data/faiss_index.bin")
    with open("data/bm25_data.pkl", "rb") as f:
        bm25_data = pickle.load(f)
    # Restore the tokenized chunks and metadata
    tokenized_chunks = bm25_data["tokenized_chunks"]
    chunk_texts = bm25_data["chunk_texts"]
    bm25 = BM25Okapi(tokenized_chunks)
    retrieved_chunks = hybrid_retrieve(
        query, chunk_texts, index, bm25, top_k=top_k, lambda_faiss=lambda_faiss
    )
    logger.info("Retrieved chunks")
    # Pack retrieved chunks into the context until the token budget is reached
    context = ""
    token_count = 0
    # context = "\n".join(retrieved_chunks)
    for chunk in retrieved_chunks:
        chunk_tokens = tokenizer(chunk, return_tensors="pt")["input_ids"].shape[1]
        if token_count + chunk_tokens < MAX_CONTEXT_TOKENS:
            context += chunk + "\n"
            token_count += chunk_tokens
        else:
            break
    prompt = (
        "You are a financial analyst. Answer financial queries concisely using only the numerical data "
        "explicitly present in the provided financial context:\n\n"
        f"{context}\n\n"
        "Use only the given financial data—do not assume, infer, or generate missing values."
        " Retain the original format of financial figures without conversion."
        " If the requested information is unavailable, respond with 'No relevant financial data available.'"
        " Provide a single-sentence answer without explanations, additional text, or multiple responses."
        f"\nQuery: {query}"
    )
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(DEVICE)
    inputs.pop("token_type_ids", None)
    logger.info("Generating output")
    input_len = inputs["input_ids"].shape[-1]
    logger.info(f"Input len: {input_len}")
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            repetition_penalty=repetition_penalty,
            output_scores=True,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    sequences = output["sequences"][0][input_len:]
    execution_time = time.perf_counter() - start_time
    logger.info(f"Query processed in {execution_time:.2f} seconds.")
    # Get the per-step scores (logits) for each generated token
    log_probs = output["scores"]
    token_probs = [torch.softmax(lp, dim=-1) for lp in log_probs]
    # Extract the top-token probability at each step
    token_confidences = [tp.max().item() for tp in token_probs]
    # Mean top-token probability across the generated sequence
    top_token_conf = sum(token_confidences) / len(token_confidences)
    print(f"Top Token Probability Mean: {top_token_conf:.4f}")
    entropy_score = sum(compute_entropy(lp) for lp in log_probs) / len(log_probs)
    entropy_conf = 1 - (entropy_score / torch.log(torch.tensor(tokenizer.vocab_size)))
    print(f"Entropy-based Confidence: {entropy_conf:.4f}")
    # Average the top-token probability and the entropy-based confidence
    model_conf_signal = (top_token_conf + entropy_conf) / 2
    response = tokenizer.decode(sequences, skip_special_tokens=True)
    confidence_score = compute_response_confidence(
        query, response, retrieved_chunks, bm25, model_conf_signal
    )
    logger.info(f"Confidence: {confidence_score}%")
    # The score is a percentage, so flag anything at or below 30%
    if confidence_score <= 30:
        logger.error("The system is unsure about this response.")
        response += "\nThe system is unsure about this response."
| final_out = "" | |
| if not use_extraction: | |
| final_out += f"Context: {context}\nQuery: {query}\n" | |
| final_out += f"Response: {response}" | |
| return ( | |
| final_out, | |
| f"Confidence: {confidence_score}%\nTime taken: {execution_time:.2f} seconds", | |
| ) | |
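
# Illustrative programmatic call (hypothetical query; assumes process_files has
# already built data/faiss_index.bin and data/bm25_data.pkl):
#     answer, status = query_model("What was the total revenue in 2023?", top_k=10)
#     print(answer)   # context + query + generated response
#     print(status)   # e.g. "Confidence: 78.5%\nTime taken: 12.40 seconds"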

# Gradio UI
with gr.Blocks(title="Financial Statement RAG with LLM") as ui:
    gr.Markdown("## Financial Statement RAG with LLM")
    # File upload section
    with gr.Group():
        gr.Markdown("### Upload & Process Annual Reports")
        file_input = gr.File(
            file_count="multiple",
            file_types=[".pdf", ".csv"],
            type="filepath",
            label="Upload Annual Reports (PDFs/CSVs)",
        )
        process_button = gr.Button("Process Files")
        process_output = gr.Textbox(label="Processing Status", interactive=False)
    # Query model section
    with gr.Group():
        gr.Markdown("### Ask a Financial Query")
        query_input = gr.Textbox(label="Enter Query")
        with gr.Row():
            top_k_input = gr.Number(value=15, label="Top K (Default: 15)")
            lambda_faiss_input = gr.Slider(0, 1, value=0.5, label="Lambda FAISS (0-1)")
            repetition_penalty = gr.Slider(
                1, 2, value=1.2, label="Repetition Penalty (1-2)"
            )
            max_tokens_input = gr.Number(value=100, label="Max New Tokens")
        use_extraction = gr.Checkbox(label="Retrieve only the answer", value=False)
        query_button = gr.Button("Submit Query")
        query_output = gr.Textbox(label="Query Response", interactive=False)
        time_output = gr.Textbox(label="Time Taken", interactive=False)
    # Button Actions
    process_button.click(process_files, inputs=[file_input], outputs=process_output)
    query_button.click(
        query_model,
        inputs=[
            query_input,
            top_k_input,
            lambda_faiss_input,
            repetition_penalty,
            max_tokens_input,
            use_extraction,
        ],
        outputs=[query_output, time_output],
    )

# Application entry point
if __name__ == "__main__":
    logger.info("Starting Gradio server...")
    ui.launch(server_name="0.0.0.0", server_port=7860, pwa=True)