Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -45,7 +45,7 @@ except Exception:
|
|
| 45 |
os.makedirs("static", exist_ok=True)
|
| 46 |
os.makedirs("temp", exist_ok=True)
|
| 47 |
|
| 48 |
-
#
|
| 49 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
|
| 51 |
# Initialize FastAPI
|
|
@@ -64,16 +64,13 @@ app.add_middleware(
|
|
| 64 |
document_storage = {}
|
| 65 |
chat_history = []
|
| 66 |
|
| 67 |
-
# Function to store document context by task ID
|
| 68 |
def store_document_context(task_id, text):
|
| 69 |
document_storage[task_id] = text
|
| 70 |
return True
|
| 71 |
|
| 72 |
-
# Function to load document context by task ID
|
| 73 |
def load_document_context(task_id):
|
| 74 |
return document_storage.get(task_id, "")
|
| 75 |
|
| 76 |
-
# Utility to compute MD5 hash from file content
|
| 77 |
def compute_md5(content: bytes) -> str:
|
| 78 |
return hashlib.md5(content).hexdigest()
|
| 79 |
|
|
@@ -196,17 +193,14 @@ try:
|
|
| 196 |
spacy.cli.download("en_core_web_sm")
|
| 197 |
nlp = spacy.load("en_core_web_sm")
|
| 198 |
print("✅ Loading NLP models...")
|
| 199 |
-
# Use
|
| 200 |
summarizer = pipeline(
|
| 201 |
"summarization",
|
| 202 |
-
model="
|
| 203 |
-
tokenizer="
|
| 204 |
device=0 if torch.cuda.is_available() else -1
|
| 205 |
)
|
| 206 |
-
#
|
| 207 |
-
# if device == "cuda":
|
| 208 |
-
# summarizer.model.half()
|
| 209 |
-
|
| 210 |
embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
|
| 211 |
ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
|
| 212 |
speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
|
|
@@ -373,7 +367,10 @@ async def analyze_legal_document(file: UploadFile = File(...)):
|
|
| 373 |
if not text:
|
| 374 |
return {"status": "error", "message": "No valid text found in the document."}
|
| 375 |
summary_text = text[:4096] if len(text) > 4096 else text
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
| 377 |
entities = extract_named_entities(text)
|
| 378 |
risk_analysis = analyze_risk_enhanced(text)
|
| 379 |
clauses = analyze_contract_clauses(text)
|
|
@@ -411,7 +408,10 @@ async def analyze_legal_video(file: UploadFile = File(...), background_tasks: Ba
|
|
| 411 |
with open(transcript_path, "w") as f:
|
| 412 |
f.write(text)
|
| 413 |
summary_text = text[:4096] if len(text) > 4096 else text
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
| 415 |
entities = extract_named_entities(text)
|
| 416 |
risk_analysis = analyze_risk_enhanced(text)
|
| 417 |
clauses = analyze_contract_clauses(text)
|
|
@@ -451,7 +451,10 @@ async def analyze_legal_audio(file: UploadFile = File(...), background_tasks: Ba
|
|
| 451 |
with open(transcript_path, "w") as f:
|
| 452 |
f.write(text)
|
| 453 |
summary_text = text[:4096] if len(text) > 4096 else text
|
| 454 |
-
|
|
|
|
|
|
|
|
|
|
| 455 |
entities = extract_named_entities(text)
|
| 456 |
risk_analysis = analyze_risk_enhanced(text)
|
| 457 |
clauses = analyze_contract_clauses(text)
|
|
|
|
| 45 |
os.makedirs("static", exist_ok=True)
|
| 46 |
os.makedirs("temp", exist_ok=True)
|
| 47 |
|
| 48 |
+
# Set device to GPU if available
|
| 49 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 50 |
|
| 51 |
# Initialize FastAPI
|
|
|
|
| 64 |
document_storage = {}
|
| 65 |
chat_history = []
|
| 66 |
|
|
|
|
| 67 |
def store_document_context(task_id, text):
|
| 68 |
document_storage[task_id] = text
|
| 69 |
return True
|
| 70 |
|
|
|
|
| 71 |
def load_document_context(task_id):
|
| 72 |
return document_storage.get(task_id, "")
|
| 73 |
|
|
|
|
| 74 |
def compute_md5(content: bytes) -> str:
|
| 75 |
return hashlib.md5(content).hexdigest()
|
| 76 |
|
|
|
|
| 193 |
spacy.cli.download("en_core_web_sm")
|
| 194 |
nlp = spacy.load("en_core_web_sm")
|
| 195 |
print("✅ Loading NLP models...")
|
| 196 |
+
# Use T5-base for summarization; run it on GPU (device=0) when CUDA is available, otherwise on CPU (device=-1)
|
| 197 |
summarizer = pipeline(
|
| 198 |
"summarization",
|
| 199 |
+
model="t5-base",
|
| 200 |
+
tokenizer="t5-base",
|
| 201 |
device=0 if torch.cuda.is_available() else -1
|
| 202 |
)
|
| 203 |
+
# Do NOT convert the summarizer model to FP16; skipping the conversion reduces the risk of CUDA errors
|
|
|
|
|
|
|
|
|
|
| 204 |
embedding_model = SentenceTransformer("all-mpnet-base-v2", device=device)
|
| 205 |
ner_model = pipeline("ner", model="dslim/bert-base-NER", device=0 if torch.cuda.is_available() else -1)
|
| 206 |
speech_to_text = pipeline("automatic-speech-recognition", model="openai/whisper-medium", chunk_length_s=30,
|
|
|
|
| 367 |
if not text:
|
| 368 |
return {"status": "error", "message": "No valid text found in the document."}
|
| 369 |
summary_text = text[:4096] if len(text) > 4096 else text
|
| 370 |
+
summary_result = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)
|
| 371 |
+
summary = summary_result[0].get("summary_text", "")
|
| 372 |
+
if not summary:
|
| 373 |
+
summary = "Summary not generated. Please check the input text."
|
| 374 |
entities = extract_named_entities(text)
|
| 375 |
risk_analysis = analyze_risk_enhanced(text)
|
| 376 |
clauses = analyze_contract_clauses(text)
|
|
|
|
| 408 |
with open(transcript_path, "w") as f:
|
| 409 |
f.write(text)
|
| 410 |
summary_text = text[:4096] if len(text) > 4096 else text
|
| 411 |
+
summary_result = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)
|
| 412 |
+
summary = summary_result[0].get("summary_text", "")
|
| 413 |
+
if not summary:
|
| 414 |
+
summary = "Summary not generated. Please check the input transcript."
|
| 415 |
entities = extract_named_entities(text)
|
| 416 |
risk_analysis = analyze_risk_enhanced(text)
|
| 417 |
clauses = analyze_contract_clauses(text)
|
|
|
|
| 451 |
with open(transcript_path, "w") as f:
|
| 452 |
f.write(text)
|
| 453 |
summary_text = text[:4096] if len(text) > 4096 else text
|
| 454 |
+
summary_result = summarizer(summary_text, max_length=200, min_length=50, do_sample=False)
|
| 455 |
+
summary = summary_result[0].get("summary_text", "")
|
| 456 |
+
if not summary:
|
| 457 |
+
summary = "Summary not generated. Please check the input transcript."
|
| 458 |
entities = extract_named_entities(text)
|
| 459 |
risk_analysis = analyze_risk_enhanced(text)
|
| 460 |
clauses = analyze_contract_clauses(text)
|