Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import warnings | |
| import vertexai | |
| from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFacePipeline | |
| from vertexai.generative_models import GenerativeModel, GenerationConfig | |
| import gradio as gr | |
| from fastapi import FastAPI, UploadFile, File | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if any) into the process environment
# before any of the settings below are read.
load_dotenv()

# Silence every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# --- Runtime settings, each overridable through an environment variable ---

# Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model identifiers: the local seq2seq model used for PII filtering and the
# Gemini model used for summarization.
FILTER_MODEL_ID = os.environ.get("FILTER_MODEL", "google/flan-t5-small")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.0-flash-001")

# Google Cloud project and region passed to the Vertex AI initializers.
PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "gahld-469906")
LOCATION = os.environ.get("GCP_REGION", "us-central1")

# Generation parameters. Environment values arrive as strings, so the defaults
# are written as strings too and everything goes through the same conversion.
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "512"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.2"))
TOP_P = float(os.environ.get("TOP_P", "0.95"))
def load_filter_model(model_id=FILTER_MODEL_ID):
    """Build the LangChain-wrapped seq2seq pipeline used for PII filtering.

    Args:
        model_id: Identifier handed to ``from_pretrained`` for both the model
            and its tokenizer. Defaults to ``FILTER_MODEL_ID``.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``text2text-generation`` pipeline
        whose output is capped at ``MAX_NEW_TOKENS`` new tokens.
    """
    generator = pipeline(
        task="text2text-generation",
        model=AutoModelForSeq2SeqLM.from_pretrained(model_id),
        tokenizer=AutoTokenizer.from_pretrained(model_id),
        max_new_tokens=MAX_NEW_TOKENS,
    )
    return HuggingFacePipeline(pipeline=generator)
# Build the local filter LLM once at import time so every request reuses it.
filter_llm = load_filter_model()

import json
import os

from google.cloud import aiplatform
from google.oauth2 import service_account

# GOOGLE_APPLICATION_CREDENTIALS may carry either the service-account JSON
# document itself or a path to a key file on disk.
creds_env = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
_creds_source = (creds_env or "").strip()

if not _creds_source:
    # Nothing configured: pass credentials=None so the Google SDKs resolve
    # Application Default Credentials, instead of crashing at import time on
    # from_service_account_file(None).
    creds = None
elif _creds_source.startswith("{"):
    # The variable holds the JSON key inline.
    creds_info = json.loads(_creds_source)
    creds = service_account.Credentials.from_service_account_info(creds_info)
else:
    # The variable holds a path; surrounding whitespace/newlines are stripped
    # so a sloppily pasted value still resolves to the file.
    creds = service_account.Credentials.from_service_account_file(_creds_source)

# Both SDK entry points are initialized with the same project, region and
# credentials.
aiplatform.init(project=PROJECT_ID, location=LOCATION, credentials=creds)
vertexai.init(project=PROJECT_ID, location=LOCATION, credentials=creds)

# Shared Gemini handle and generation settings used by summarizer().
gemini_model = GenerativeModel(GEMINI_MODEL)
gen_config = GenerationConfig(
    temperature=TEMPERATURE,
    top_p=TOP_P,
    max_output_tokens=8192,
)
def pii_filter(chunk: str) -> str:
    """Ask the local filter LLM to redact personal data from one text chunk.

    Args:
        chunk: Raw text of a single document chunk.

    Returns:
        The model's output for the redaction prompt. A chunk that is empty or
        whitespace-only is returned unchanged without calling the model, since
        there is nothing to redact.
    """
    if not chunk or not chunk.strip():
        return chunk

    prompt = (
        "\n"
        "You are a privacy filter. Review the following text and redact or anonymize any sensitive data\n"
        "(names, emails, phone numbers, IDs, addresses, or any personally identifiable information).\n"
        "Keep the meaning intact.\n"
        'If the chunk only contains a section title like "1. Bank Policies", return it unchanged.\n'
        "Text:\n"
        f"{chunk}\n"
        "Return the safe-to-share version:\n"
    )
    # Use the Runnable entry point; calling the LLM object directly relies on
    # the deprecated __call__ path of LangChain LLMs.
    return filter_llm.invoke(prompt)
def summarizer(texts):
    """Summarize the given chunks with the Gemini model.

    Args:
        texts: Iterable of objects exposing a ``page_content`` string; their
            contents are joined with blank lines to form the document body.

    Returns:
        The ``text`` attribute of the Gemini response.
    """
    instructions = (
        "You are a helpful legal assistant to help summarize the anonymized document in this prompt.\n"
        "Structure the output as follows:\n"
        "**SUMMARY**\n"
        "1) TLDR (2-3 lines)\n"
        "2) Key Policies\n"
        "3) Prohibited Actions\n"
        "4) Exceptions/Notes (if necessary)\n"
        "5) Clauses Mentioned (if necessary)\n"
    )
    document_body = "\n\n".join(chunk.page_content for chunk in texts)
    full_prompt = f"{instructions}\n\nDocument:\n{document_body}"

    reply = gemini_model.generate_content(
        full_prompt,
        generation_config=gen_config,
    )
    return reply.text
def summarize_document(file_path: str):
    """Load a UTF-8 text file, redact each chunk, and summarize the result.

    Args:
        file_path: Path of the text file to process.

    Returns:
        The summary produced by ``summarizer`` over the redacted chunks.
    """
    raw_docs = TextLoader(file_path, encoding="utf-8").load()

    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_size=7000,
        chunk_overlap=50,
    )
    chunks = splitter.split_documents(raw_docs)

    # Overwrite each chunk's text in place with its redacted version, so only
    # filtered content is handed to the summarizer.
    for piece in chunks:
        piece.page_content = pii_filter(piece.page_content)

    return summarizer(chunks)
| # def process_file(file_obj): | |
| # if file_obj is None: | |
| # return "Please upload a document." | |
| # return summarize_document(file_obj.name) | |
| # demo = gr.Interface( | |
| # fn=process_file, | |
| # inputs=gr.File(label="Upload a document", file_types=[".txt", ".md", ".pdf"]), | |
| # outputs="text", | |
| # title="Legal Document Summarizer", | |
| # description="Upload a document. The app will anonymize sensitive info and summarize it." | |
| # ) | |
app = FastAPI()
# app = gr.mount_gradio_app(app, demo, path="/")


@app.post("/summarize")
async def summarize_api(file: UploadFile = File(...)):
    """Accept an uploaded document and return its anonymized summary.

    The upload is written to a private temporary file, passed through
    ``summarize_document``, and the temporary file is always removed
    afterwards.

    Args:
        file: The uploaded document (multipart form field ``file``).

    Returns:
        A dict of the form ``{"summary": <summary text>}``.
    """
    import tempfile

    payload = await file.read()

    # Never build the path from the client-supplied filename: it is untrusted
    # (could contain path separators or "..") and identical names from
    # concurrent requests would overwrite each other. mkstemp gives a unique,
    # server-chosen path instead.
    fd, file_path = tempfile.mkstemp(prefix="temp_")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        # NOTE(review): summarize_document is synchronous, so this call blocks
        # the event loop for its whole duration.
        summary = summarize_document(file_path)
    finally:
        # Clean up even when loading, filtering or summarizing raises.
        os.remove(file_path)

    return {"summary": summary}
@app.get("/")
def read_root():
    """Health-check endpoint: report that the API process is up.

    Returns:
        A dict with a single ``message`` key.
    """
    return {"message": "API is running."}
if __name__ == "__main__":
    # When executed as a script, serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")