"""FastAPI service that anonymizes an uploaded text document and summarizes it.

Each uploaded file is split into chunks, every chunk is passed through a
local Hugging Face seq2seq model prompted to redact personal data, and the
redacted chunks are sent to a Gemini model on Vertex AI for a structured
summary.
"""

import json
import os
import tempfile
import warnings

# NOTE: gradio, torch and BaseModel are not referenced by the active code
# below; they are kept because the parked Gradio UI (and possibly other
# tooling) relies on them being importable from this module.
import gradio as gr
import torch
import uvicorn
import vertexai
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile
from google.cloud import aiplatform
from google.oauth2 import service_account
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_huggingface import HuggingFacePipeline
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from vertexai.generative_models import GenerationConfig, GenerativeModel

load_dotenv()
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Configuration (overridable through environment variables / .env)
# ---------------------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
FILTER_MODEL_ID = os.getenv("FILTER_MODEL", "google/flan-t5-small")
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-001")
PROJECT_ID = os.getenv("GCP_PROJECT_ID", "gahld-469906")
LOCATION = os.getenv("GCP_REGION", "us-central1")
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 512))
TEMPERATURE = float(os.getenv("TEMPERATURE", 0.2))
TOP_P = float(os.getenv("TOP_P", 0.95))


def load_filter_model(model_id=FILTER_MODEL_ID):
    """Build the LangChain LLM wrapper used for PII filtering.

    Args:
        model_id: Hugging Face model identifier of a seq2seq model.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``text2text-generation``
        pipeline limited to ``MAX_NEW_TOKENS`` generated tokens.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    text_gen = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    return HuggingFacePipeline(pipeline=text_gen)


# Loaded once at import time so every request reuses the same model.
filter_llm = load_filter_model()

# ---------------------------------------------------------------------------
# Google Cloud credentials / Vertex AI setup
# ---------------------------------------------------------------------------
# GOOGLE_APPLICATION_CREDENTIALS may hold either the service-account JSON
# itself or a path to a key file.
creds_env = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
if creds_env and creds_env.strip().startswith("{"):
    creds_info = json.loads(creds_env)
    creds = service_account.Credentials.from_service_account_info(creds_info)
elif creds_env:
    creds = service_account.Credentials.from_service_account_file(creds_env)
else:
    # Variable unset or empty: pass no explicit credentials so the SDK falls
    # back to Application Default Credentials, instead of failing while
    # trying to open a key file named None / "".
    creds = None

aiplatform.init(project=PROJECT_ID, location=LOCATION, credentials=creds)
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=creds)

gemini_model = GenerativeModel(GEMINI_MODEL)
gen_config = GenerationConfig(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_output_tokens=8192,
)


def pii_filter(chunk: str) -> str:
    """Ask the local filter model to redact personal data from ``chunk``.

    Args:
        chunk: Raw text of one document chunk.

    Returns:
        The filter model's output for the redaction prompt.
    """
    prompt = f"""
    You are a privacy filter.
    Review the following text and redact or anonymize any sensitive data
    (names, emails, phone numbers, IDs, addresses, or any personally identifiable information).
    Keep the meaning intact.
    If the chunk only contains a section title like "1. Bank Policies", return it unchanged.

    Text:
    {chunk}

    Return the safe-to-share version:
    """
    # ``invoke`` is the supported Runnable entry point; calling the LLM
    # object directly is deprecated in LangChain.
    return filter_llm.invoke(prompt)


def summarizer(texts) -> str:
    """Summarize already-anonymized chunks with the Gemini model.

    Args:
        texts: Iterable of objects exposing a ``page_content`` string
            (the chunks produced by the text splitter).

    Returns:
        The ``text`` of the Gemini response.
    """
    SUMMARY_PROMPT = """You are a helpful legal assistant to help summarize the anonymized document in this prompt.
    Structure the output as follows:
    **SUMMARY**
    1) TLDR (2-3 lines)
    2) Key Policies
    3) Prohibited Actions
    4) Exceptions/Notes (if necessary)
    5) Clauses Mentioned (if necessary)
    """
    joined_text = "\n\n".join(t.page_content for t in texts)
    response = gemini_model.generate_content(
        SUMMARY_PROMPT + "\n\nDocument:\n" + joined_text,
        generation_config=gen_config,
    )
    return response.text


def summarize_document(
    file_path: str,
    *,
    chunk_size: int = 7000,
    chunk_overlap: int = 50,
) -> str:
    """Load a UTF-8 text file, anonymize it chunk by chunk and summarize it.

    Args:
        file_path: Path of the text file to process.
        chunk_size: Maximum chunk length handed to the splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The summary text produced by ``summarizer``.
    """
    # NOTE(review): the default chunk size is large relative to
    # MAX_NEW_TOKENS, so the filter model may not be able to reproduce a
    # whole chunk in its output - confirm against the chosen FILTER_MODEL.
    documents = TextLoader(file_path, encoding="utf-8").load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    chunks = text_splitter.split_documents(documents)

    # Replace each chunk's content in place with its redacted version so
    # that only anonymized text is ever sent to the remote model.
    for doc in chunks:
        doc.page_content = pii_filter(doc.page_content)

    return summarizer(chunks)


# def process_file(file_obj):
#     if file_obj is None:
#         return "Please upload a document."
#     return summarize_document(file_obj.name)

# demo = gr.Interface(
#     fn=process_file,
#     inputs=gr.File(label="Upload a document", file_types=[".txt", ".md", ".pdf"]),
#     outputs="text",
#     title="Legal Document Summarizer",
#     description="Upload a document. The app will anonymize sensitive info and summarize it."
# )

app = FastAPI()
# app = gr.mount_gradio_app(app, demo, path="/")


@app.post("/summarize/")
async def summarize_api(file: UploadFile = File(...)):
    """Summarize an uploaded document.

    The upload is spooled to a temporary file, summarized, and the temporary
    file is removed afterwards.

    Returns:
        ``{"summary": <summary text>}``.
    """
    # The client-supplied filename is untrusted, so it is never used to build
    # a path: a server-generated temp name avoids path traversal and clashes
    # between concurrent uploads that share a filename.
    with tempfile.NamedTemporaryFile(
        mode="wb", prefix="temp_", delete=False
    ) as tmp:
        tmp.write(await file.read())
        file_path = tmp.name

    try:
        summary = summarize_document(file_path)
    finally:
        # Always clean up, even when summarization fails.
        os.remove(file_path)

    return {"summary": summary}


@app.get("/")
def read_root():
    """Health-check endpoint."""
    return {"message": "API is running."}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)