nexusbert committed
Commit cb4021c · 1 Parent(s): 86b5a56
Files changed (1)
  1. app.py +136 -126
app.py CHANGED
@@ -10,7 +10,7 @@ from typing import Optional, Tuple
 from fastapi import FastAPI, UploadFile, File, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from docx import Document as DocxDocument
 from pptx import Presentation
 import logging
@@ -44,14 +44,15 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-MODEL_ID = "tiiuae/Falcon3-3B-Instruct"
-pipe = None
+MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
+model = None
+tokenizer = None
 ocr_reader = None
 
 @app.on_event("startup")
 async def load_model():
-    """Load the model pipeline and OCR reader on startup"""
-    global pipe, ocr_reader
+    """Load the Zephyr tokenizer/model and OCR reader on startup"""
+    global tokenizer, model, ocr_reader
     try:
         logger.info(f"Loading model: {MODEL_ID} ...")
         logger.info("Optimizing for CPU-only inference...")
@@ -60,19 +61,20 @@ async def load_model():
         torch.set_num_interop_threads(os.cpu_count() or 4)
 
         logger.info(f"Using {torch.get_num_threads()} CPU threads for inference")
-        logger.info("Loading full model into CPU RAM (no offloading)...")
-
-        pipe = pipeline(
-            "text-generation",
-            model=MODEL_ID,
-            dtype=torch.bfloat16,
+        logger.info("Loading Zephyr tokenizer and model (CPU)...")
+
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        desired_dtype = torch.bfloat16 if hasattr(torch, "cpu") and getattr(torch.cpu, "is_bf16_supported", lambda: False)() else torch.float32
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            dtype=desired_dtype,
             device_map="cpu",
-            model_kwargs={
-                "low_cpu_mem_usage": False,
-                "offload_folder": None
-            }
-        )
-        logger.info("✅ Model loaded successfully in CPU RAM!")
+            low_cpu_mem_usage=False
+        ).eval()
+        logger.info("✅ Zephyr loaded successfully on CPU!")
 
         logger.info("Loading OCR reader...")
         try:
@@ -437,68 +439,46 @@ Produce ONLY valid JSON with these exact fields:
 }}"""
 
     try:
-        full_prompt = f"{system_message}\n\n{user_message}"
-
-        logger.info(f"Input prompt length: {len(full_prompt)} characters")
-        logger.info("Starting model generation with pipeline...")
-
-        start_time = time.time()
-
         messages = [
-            {"role": "user", "content": full_prompt}
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": user_message}
         ]
-
-        result = pipe(
-            messages,
-            max_new_tokens=800,
-            temperature=0.2,
-            do_sample=True,
-            top_p=0.9,
-            return_full_text=False
-        )
-
+
+        start_time = time.time()
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3800).to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=800,
+                temperature=0.2,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.08,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
         generation_time = time.time() - start_time
-        raw_output = result[0]["generated_text"]
-        logger.info(f" Generated {len(raw_output)} characters in {generation_time:.2f}s ({len(raw_output)/generation_time:.1f} chars/sec)")
-
-        start = raw_output.find('{')
-        end = raw_output.rfind('}') + 1
-
-        if start == -1 or end == 0:
-            logger.warning("No JSON found in output, returning raw output")
-            raise ValueError(f"No JSON object found in model output. Raw output: {raw_output[:500]}")
-
-        json_str = raw_output[start:end]
-
+        raw_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if "<|assistant|>" in raw_text:
+            raw_text = raw_text.split("<|assistant|>")[-1]
+
+        logger.info(f"✅ Generated {len(raw_text)} chars in {generation_time:.2f}s")
+
+        start = raw_text.find('{')
+        end = raw_text.rfind('}') + 1
+        if start == -1 or end <= 0:
+            raise ValueError("No JSON object found in model output")
+
+        json_str = raw_text[start:end]
         try:
-            parsed_json = json.loads(json_str)
-            return parsed_json
+            return json.loads(json_str)
         except json.JSONDecodeError as e:
-            logger.warning(f"JSON parsing failed, attempting to repair: {e}")
-            try:
-                repaired_json = _repair_json(json_str)
-                parsed_json = json.loads(repaired_json)
-                logger.info("✅ JSON successfully repaired")
-                return parsed_json
-            except Exception as repair_error:
-                logger.error(f"JSON repair also failed: {repair_error}")
-                logger.error(f"Problematic JSON (around error): {json_str[max(0, e.pos-200):e.pos+200]}")
-
-                try:
-                    import json5
-                    parsed_json = json5.loads(json_str)
-                    logger.info("✅ JSON5 parsing succeeded as fallback")
-                    return parsed_json
-                except ImportError:
-                    pass
-                except Exception:
-                    pass
-
-                raise ValueError(f"Failed to parse JSON from model output at position {e.pos}: {str(e)}. JSON preview: {json_str[max(0, e.pos-200):e.pos+200]}")
-
-    except json.JSONDecodeError as e:
-        logger.error(f"JSON parsing error: {e}")
-        raise ValueError(f"Failed to parse JSON from model output: {str(e)}")
+            repaired = _repair_json(json_str)
+            return json.loads(repaired)
+
     except Exception as e:
         logger.error(f"Model generation error: {e}")
         raise ValueError(f"Error during model inference: {str(e)}")
@@ -522,19 +502,28 @@ Full Deck Length: {len(full_text)} characters
 Produce a FINAL comprehensive review with the same JSON structure as before, consolidating all findings."""
 
     try:
-        full_prompt = f"{system_message}\n\n{user_message}"
-        messages = [{"role": "user", "content": full_prompt}]
-
-        result = pipe(
-            messages,
-            max_new_tokens=800,
-            temperature=0.2,
-            do_sample=True,
-            top_p=0.9,
-            return_full_text=False
-        )
-
-        raw_output = result[0]["generated_text"]
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": user_message}
+        ]
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3800).to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=800,
+                temperature=0.2,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.05,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if "<|assistant|>" in raw_output:
+            raw_output = raw_output.split("<|assistant|>")[-1]
 
         start = raw_output.find('{')
         end = raw_output.rfind('}') + 1
@@ -554,11 +543,9 @@ Produce a FINAL comprehensive review with the same JSON structure as before, con
 
     try:
         combined_json = json.loads(json_str)
-    except json.JSONDecodeError as e:
-        logger.warning(f"JSON parsing failed in combine, attempting repair: {e}")
+    except json.JSONDecodeError:
         try:
-            repaired_json = _repair_json(json_str)
-            combined_json = json.loads(repaired_json)
+            combined_json = json.loads(_repair_json(json_str))
         except Exception:
             logger.warning("JSON repair failed, returning basic structure")
             return {
@@ -641,21 +628,28 @@ Return ONLY valid JSON:
 }}"""
 
     try:
-        full_prompt = f"{system_message}\n\n{user_message}"
-        messages = [{"role": "user", "content": full_prompt}]
-
-        result = pipe(
-            messages,
-            max_new_tokens=600,
-            temperature=0.25,
-            do_sample=True,
-            top_p=0.9,
-            return_full_text=False,
-            pad_token_id=None,
-            eos_token_id=None
-        )
-
-        raw_output = result[0]["generated_text"]
+        messages = [
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": user_message}
+        ]
+        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=3600).to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=600,
+                temperature=0.25,
+                top_p=0.9,
+                do_sample=True,
+                repetition_penalty=1.05,
+                pad_token_id=tokenizer.eos_token_id,
+                use_cache=True
+            )
+
+        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        if "<|assistant|>" in raw_output:
+            raw_output = raw_output.split("<|assistant|>")[-1]
 
         start = raw_output.find('{')
         end = raw_output.rfind('}') + 1
@@ -670,11 +664,9 @@ Return ONLY valid JSON:
 
     try:
         improvement_json = json.loads(json_str)
-    except json.JSONDecodeError as e:
-        logger.warning(f"JSON parsing failed in improvements, attempting repair: {e}")
+    except json.JSONDecodeError:
         try:
-            repaired_json = _repair_json(json_str)
-            improvement_json = json.loads(repaired_json)
+            improvement_json = json.loads(_repair_json(json_str))
         except Exception:
             logger.warning("JSON repair failed, returning default improvement structure")
             return {
@@ -707,7 +699,7 @@ async def health():
     """Health check endpoint"""
     return {
        "status": "healthy",
-        "model_loaded": pipe is not None
+        "model_loaded": (model is not None and tokenizer is not None)
     }
 
 @app.post("/review")
@@ -717,17 +709,26 @@ async def review_deck(file: UploadFile = File(...)):
 
     Supported formats: PDF, DOCX, PPT, PPTX
     """
-    if pipe is None:
-        raise HTTPException(status_code=503, detail="Model not loaded yet. Please wait for startup to complete.")
-
-    file_extension = Path(file.filename).suffix.lower()
-    supported_extensions = [".pdf", ".docx", ".doc", ".ppt", ".pptx"]
-
-    if file_extension not in supported_extensions:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported file type: {file_extension}. Supported: {', '.join(supported_extensions)}"
-        )
+    try:
+        if model is None or tokenizer is None:
+            raise HTTPException(status_code=503, detail="Model not loaded yet. Please wait for startup to complete.")
+
+        if not file.filename:
+            raise HTTPException(status_code=400, detail="Filename is missing")
+
+        file_extension = Path(file.filename).suffix.lower()
+        supported_extensions = [".pdf", ".docx", ".doc", ".ppt", ".pptx"]
+
+        if file_extension not in supported_extensions:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Unsupported file type: {file_extension}. Supported: {', '.join(supported_extensions)}"
+            )
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error in request validation: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Request validation error: {str(e)}")
 
     temp_file = None
     try:
@@ -758,20 +759,29 @@ async def review_deck(file: UploadFile = File(...)):
             logger.info("Review generated successfully")
 
             logger.info("Checking if improvement pointers are needed...")
-            improvement_pointers = generate_improvement_pointers(review_result)
-            review_result["improvement_analysis"] = improvement_pointers
+            try:
+                improvement_pointers = generate_improvement_pointers(review_result)
+                review_result["improvement_analysis"] = improvement_pointers
+            except Exception as imp_error:
+                logger.warning(f"Improvement pointers generation failed: {imp_error}, continuing without it")
+                review_result["improvement_analysis"] = {
+                    "needs_improvement": True,
+                    "improvement_pointers": [],
+                    "error": "Failed to generate improvement pointers"
+                }
 
             return JSONResponse(content=review_result)
         except ValueError as e:
+            logger.error(f"ValueError in review generation: {e}", exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))
         except Exception as e:
-            logger.error(f"Review generation error: {e}")
+            logger.error(f"Review generation error: {e}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Error generating review: {str(e)}")
 
     except HTTPException:
         raise
     except Exception as e:
-        logger.error(f"Unexpected error: {e}")
+        logger.error(f"Unexpected error in review endpoint: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
 
     finally:
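
For context, a minimal standalone sketch of the generation pattern this commit adopts (Zephyr chat template → CPU generate → decode → slice out the JSON object). The model ID matches the commit; the prompts and generation settings below are illustrative only, and note the 7B model needs substantial RAM on CPU:

import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="cpu").eval()

# Zephyr's chat template wraps turns in <|system|>/<|user|>/<|assistant|> tags.
messages = [
    {"role": "system", "content": "Reply with a single JSON object."},  # illustrative prompt
    {"role": "user", "content": "Return {\"ok\": true} and nothing else."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64, do_sample=False,
                             pad_token_id=tokenizer.eos_token_id)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)
text = text.split("<|assistant|>")[-1]  # keep only the assistant turn; no-op if tag absent

# Slice out the outermost {...} span and parse it, as the endpoint code does.
start, end = text.find("{"), text.rfind("}") + 1
print(json.loads(text[start:end]))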
 
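And a minimal client call against the /review endpoint, which takes a multipart upload in a form field named "file". This sketch assumes the app is reachable on localhost:7860 (a common Spaces port; adjust to your deployment), that deck.pdf exists, and that the requests package is installed:

import requests

with open("deck.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/review",
        files={"file": ("deck.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
review = resp.json()  # parsed review JSON, including "improvement_analysis"
print(review)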