push

app.py CHANGED
@@ -10,7 +10,7 @@ from typing import Optional, Tuple
 from fastapi import FastAPI, UploadFile, File, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-from transformers import
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from docx import Document as DocxDocument
 from pptx import Presentation
 import logging
@@ -44,14 +44,15 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-MODEL_ID = "
-
+MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
+model = None
+tokenizer = None
 ocr_reader = None
 
 @app.on_event("startup")
 async def load_model():
-    """Load the model
-    global
+    """Load the Zephyr tokenizer/model and OCR reader on startup"""
+    global tokenizer, model, ocr_reader
     try:
         logger.info(f"Loading model: {MODEL_ID} ...")
         logger.info("Optimizing for CPU-only inference...")
@@ -60,19 +61,20 @@ async def load_model():
         torch.set_num_interop_threads(os.cpu_count() or 4)
 
         logger.info(f"Using {torch.get_num_threads()} CPU threads for inference")
-        logger.info("Loading
-
-
-
-
-
+        logger.info("Loading Zephyr tokenizer and model (CPU)...")
+
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        desired_dtype = torch.bfloat16 if hasattr(torch, "cpu") and getattr(torch.cpu, "is_bf16_supported", lambda: False)() else torch.float32
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            dtype=desired_dtype,
             device_map="cpu",
-
-
-
-        }
-        )
-        logger.info("✅ Model loaded successfully in CPU RAM!")
+            low_cpu_mem_usage=False
+        ).eval()
+        logger.info("✅ Zephyr loaded successfully on CPU!")
 
         logger.info("Loading OCR reader...")
         try:
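Note: the generation hunks below split the decoded output on the literal "<|assistant|>" marker. That works because Zephyr's bundled chat template tags each turn with special markers, so everything after the final assistant tag is the model's reply. A quick way to inspect the prompt shape (illustrative sketch; the exact markers come from the model's own chat template):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    messages = [
        {"role": "system", "content": "You are a pitch deck reviewer."},
        {"role": "user", "content": "Review this deck."},
    ]
    # apply_chat_template renders the turns with Zephyr's turn markers
    print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
    # Expected shape (approximately):
    # <|system|>
    # You are a pitch deck reviewer.</s>
    # <|user|>
    # Review this deck.</s>
    # <|assistant|>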
@@ -437,68 +439,46 @@ Produce ONLY valid JSON with these exact fields:
 }}"""
 
     try:
-        full_prompt = f"{system_message}\n\n{user_message}"
-
-        logger.info(f"Input prompt length: {len(full_prompt)} characters")
-        logger.info("Starting model generation with pipeline...")
-
-        start_time = time.time()
-
         messages = [
-            {"role": "
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": user_message}
         ]
-
-
-
-
-
-
-
-
-
-
+
+        start_time = time.time()
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3800).to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=800,
+                temperature=0.2,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.08,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
         generation_time = time.time() - start_time
-
-
-
-
-
-
-
-
-
-
-
-
+        raw_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if "<|assistant|>" in raw_text:
+            raw_text = raw_text.split("<|assistant|>")[-1]
+
+        logger.info(f"✅ Generated {len(raw_text)} chars in {generation_time:.2f}s")
+
+        start = raw_text.find('{')
+        end = raw_text.rfind('}') + 1
+        if start == -1 or end <= 0:
+            raise ValueError("No JSON object found in model output")
+
+        json_str = raw_text[start:end]
         try:
-
-            return parsed_json
+            return json.loads(json_str)
         except json.JSONDecodeError as e:
-
-
-
-            parsed_json = json.loads(repaired_json)
-            logger.info("✅ JSON successfully repaired")
-            return parsed_json
-        except Exception as repair_error:
-            logger.error(f"JSON repair also failed: {repair_error}")
-            logger.error(f"Problematic JSON (around error): {json_str[max(0, e.pos-200):e.pos+200]}")
-
-            try:
-                import json5
-                parsed_json = json5.loads(json_str)
-                logger.info("✅ JSON5 parsing succeeded as fallback")
-                return parsed_json
-            except ImportError:
-                pass
-            except Exception:
-                pass
-
-            raise ValueError(f"Failed to parse JSON from model output at position {e.pos}: {str(e)}. JSON preview: {json_str[max(0, e.pos-200):e.pos+200]}")
-
-        except json.JSONDecodeError as e:
-            logger.error(f"JSON parsing error: {e}")
-            raise ValueError(f"Failed to parse JSON from model output: {str(e)}")
+            repaired = _repair_json(json_str)
+            return json.loads(repaired)
+
     except Exception as e:
         logger.error(f"Model generation error: {e}")
         raise ValueError(f"Error during model inference: {str(e)}")
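Note: both new parse paths call a _repair_json helper that is defined elsewhere in app.py and not shown in this diff. A minimal sketch of the kind of cleanup such a helper typically performs (an assumption about its behavior, not the repository's actual code):

    import re

    def _repair_json(json_str: str) -> str:
        """Best-effort cleanup of almost-valid JSON emitted by an LLM (hypothetical sketch)."""
        s = json_str.strip()
        # Strip markdown code fences the model may wrap around the JSON.
        s = re.sub(r"^```(?:json)?\s*|\s*```$", "", s)
        # Drop trailing commas before a closing brace/bracket.
        s = re.sub(r",\s*([}\]])", r"\1", s)
        # Naively close unbalanced brackets left by a truncated generation
        # (ignores braces inside strings; good enough for a last-ditch fallback).
        s += "]" * max(0, s.count("[") - s.count("]"))
        s += "}" * max(0, s.count("{") - s.count("}"))
        return s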
@@ -522,19 +502,28 @@ Full Deck Length: {len(full_text)} characters
 Produce a FINAL comprehensive review with the same JSON structure as before, consolidating all findings."""
 
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": user_message}
+        ]
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3800).to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=800,
+                temperature=0.2,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.05,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if "<|assistant|>" in raw_output:
+            raw_output = raw_output.split("<|assistant|>")[-1]
 
     start = raw_output.find('{')
     end = raw_output.rfind('}') + 1
@@ -554,11 +543,9 @@ Produce a FINAL comprehensive review with the same JSON structure as before, con
 
     try:
         combined_json = json.loads(json_str)
-    except json.JSONDecodeError as e:
-        logger.warning(f"JSON parsing failed in combine, attempting repair: {e}")
+    except json.JSONDecodeError:
         try:
-
-            combined_json = json.loads(repaired_json)
+            combined_json = json.loads(_repair_json(json_str))
         except Exception:
             logger.warning("JSON repair failed, returning basic structure")
             return {
@@ -641,21 +628,28 @@ Return ONLY valid JSON:
 }}"""
 
     try:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": user_message}
+        ]
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3600).to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=600,
+                temperature=0.25,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.05,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if "<|assistant|>" in raw_output:
+            raw_output = raw_output.split("<|assistant|>")[-1]
 
     start = raw_output.find('{')
     end = raw_output.rfind('}') + 1
@@ -670,11 +664,9 @@ Return ONLY valid JSON:
 
     try:
         improvement_json = json.loads(json_str)
-    except json.JSONDecodeError as e:
-        logger.warning(f"JSON parsing failed in improvements, attempting repair: {e}")
+    except json.JSONDecodeError:
         try:
-
-            improvement_json = json.loads(repaired_json)
+            improvement_json = json.loads(_repair_json(json_str))
         except Exception:
             logger.warning("JSON repair failed, returning default improvement structure")
             return {
@@ -707,7 +699,7 @@ async def health():
     """Health check endpoint"""
     return {
         "status": "healthy",
-        "model_loaded":
+        "model_loaded": (model is not None and tokenizer is not None)
     }
 
 @app.post("/review")
@@ -717,17 +709,26 @@ async def review_deck(file: UploadFile = File(...)):
 
     Supported formats: PDF, DOCX, PPT, PPTX
     """
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        if model is None or tokenizer is None:
+            raise HTTPException(status_code=503, detail="Model not loaded yet. Please wait for startup to complete.")
+
+        if not file.filename:
+            raise HTTPException(status_code=400, detail="Filename is missing")
+
+        file_extension = Path(file.filename).suffix.lower()
+        supported_extensions = [".pdf", ".docx", ".doc", ".ppt", ".pptx"]
+
+        if file_extension not in supported_extensions:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Unsupported file type: {file_extension}. Supported: {', '.join(supported_extensions)}"
+            )
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in request validation: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Request validation error: {str(e)}")
 
     temp_file = None
     try:
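Note: the new validation block can be exercised end to end with a local smoke test (base URL, port, and file name are assumptions for a dev run):

    import requests

    # Hypothetical local run, e.g.: uvicorn app:app --port 8000
    with open("deck.pdf", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/review",
            files={"file": ("deck.pdf", f, "application/pdf")},
        )
    print(resp.status_code, resp.json())

An unsupported extension (e.g. .txt) should come back as a 400 with the supported list in the detail message.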
@@ -758,20 +759,29 @@ async def review_deck(file: UploadFile = File(...)):
         logger.info("Review generated successfully")
 
         logger.info("Checking if improvement pointers are needed...")
-
-
+        try:
+            improvement_pointers = generate_improvement_pointers(review_result)
+            review_result["improvement_analysis"] = improvement_pointers
+        except Exception as imp_error:
+            logger.warning(f"Improvement pointers generation failed: {imp_error}, continuing without it")
+            review_result["improvement_analysis"] = {
+                "needs_improvement": True,
+                "improvement_pointers": [],
+                "error": "Failed to generate improvement pointers"
+            }
 
         return JSONResponse(content=review_result)
     except ValueError as e:
+        logger.error(f"ValueError in review generation: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
     except Exception as e:
-        logger.error(f"Review generation error: {e}")
+        logger.error(f"Review generation error: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"Error generating review: {str(e)}")
 
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Unexpected error: {e}")
+        logger.error(f"Unexpected error in review endpoint: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
 
     finally: