push new model

Files changed:
- Dockerfile  (+2 -1)
- app.py  (+38 -107)
Dockerfile
CHANGED

@@ -38,7 +38,8 @@ ENV HF_HOME=/models/huggingface \
 # Create cache dir and set permissions
 RUN mkdir -p /models/huggingface && chmod -R 777 /models/huggingface

-#
+# Pre-download the model during build
+RUN python -c "from transformers import pipeline; import torch; pipe = pipeline('text-generation', model='tiiuae/Falcon3-3B-Instruct', torch_dtype=torch.bfloat16, device_map='cpu')" || true

 # Copy project files
 COPY . .
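The pre-download only pays off because the runtime cache matches the build-time cache, i.e. HF_HOME stays pinned to /models/huggingface. As a rough sanity check (a hypothetical helper, not part of this commit, assuming the standard huggingface_hub cache layout), you can confirm the weights actually landed in the image:

    # check_cache.py - hypothetical sanity check, assumes HF_HOME=/models/huggingface
    import os
    from pathlib import Path

    cache_root = Path(os.environ.get("HF_HOME", "/models/huggingface"))
    # huggingface_hub stores models under hub/models--<org>--<name>/snapshots/<revision>
    snapshots = list(cache_root.glob("hub/models--tiiuae--Falcon3-3B-Instruct/snapshots/*"))
    if snapshots:
        print("model cached at", snapshots[0])
    else:
        print("not cached; it will be downloaded at startup")

Because the RUN command ends in "|| true", a failed download does not break the build; the model would then be fetched at container startup instead.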
app.py
CHANGED

(Some deleted lines are cut off by the diff viewer; the missing text is marked with "…".)

@@ -9,7 +9,7 @@ from typing import Optional, Tuple
 from fastapi import FastAPI, UploadFile, File, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-from transformers import …
+from transformers import pipeline
 from docx import Document as DocxDocument
 from pptx import Presentation
 import logging
@@ -43,20 +43,19 @@ app.add_middleware(
     allow_headers=["*"],
 )

-MODEL_ID = "…
-…
-model = None
+MODEL_ID = "tiiuae/Falcon3-3B-Instruct"
+pipe = None
 ocr_reader = None

 @app.on_event("startup")
 async def load_model():
-    """Load the model …
-    global …
+    """Load the model pipeline and OCR reader on startup"""
+    global pipe, ocr_reader
     try:
         logger.info(f"Loading model: {MODEL_ID} ...")
-        …
-        …
-            MODEL_ID,
+        pipe = pipeline(
+            "text-generation",
+            model=MODEL_ID,
             torch_dtype=torch.bfloat16,
             device_map="auto"
         )
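The startup hook now builds one shared text-generation pipeline instead of a separate model and tokenizer. A minimal standalone sketch of the same call (outside FastAPI, assuming a transformers version recent enough to accept chat-style pipeline inputs):

    # standalone sketch of the new startup path - not part of app.py
    import torch
    from transformers import pipeline

    pipe = pipeline(
        "text-generation",
        model="tiiuae/Falcon3-3B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",  # uses a GPU if one is available, otherwise CPU
    )

    # chat-style input: the pipeline applies the model's chat template internally
    messages = [{"role": "user", "content": "Reply with the single word: ready"}]
    result = pipe(messages, max_new_tokens=10, do_sample=False, return_full_text=False)
    print(result[0]["generated_text"])

With return_full_text=False the pipeline returns only the newly generated text, which is what the endpoints below parse for JSON.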
@@ -381,53 +380,33 @@ Produce ONLY valid JSON with these exact fields:
 }}"""

     try:
-        …
-            {"role": "system", "content": system_message},
-            {"role": "user", "content": user_message}
-        ]
-        …
-        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        …
-        prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
-        prompt_token_count = prompt_tokens.shape[1]
-
-        max_context = 4096
-        max_input_tokens = 3800
-
-        if prompt_token_count > max_input_tokens:
-            logger.warning(f"Prompt is {prompt_token_count} tokens, truncating to {max_input_tokens}")
-            prompt_tokens = prompt_tokens[:, :max_input_tokens]
-            prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
-            prompt_token_count = max_input_tokens
-
-        max_output_tokens = max_context - prompt_token_count - 50
+        full_prompt = f"{system_message}\n\n{user_message}"

-        logger.info(f"Input …
+        logger.info(f"Input prompt length: {len(full_prompt)} characters")
+        logger.info("Starting model generation with pipeline...")

-        …
-        …
-        …
-        logger.info(f"Setting max_new_tokens to {output_limit}")
+        messages = [
+            {"role": "user", "content": full_prompt}
+        ]

-        …
-        …
-            max_new_tokens=…
+        result = pipe(
+            messages,
+            max_new_tokens=1500,
             temperature=0.3,
             do_sample=True,
             top_p=0.95,
-            …
-            use_cache=True
+            return_full_text=False
         )
-        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

-        …
-        …
+        raw_output = result[0]["generated_text"]
+        logger.info(f"✅ Generated {len(raw_output)} characters of output")

         start = raw_output.find('{')
         end = raw_output.rfind('}') + 1

         if start == -1 or end == 0:
-            …
+            logger.warning("No JSON found in output, returning raw output")
+            raise ValueError(f"No JSON object found in model output. Raw output: {raw_output[:500]}")

         parsed_json = json.loads(raw_output[start:end])
         return parsed_json
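All three generation paths locate the model's JSON by slicing from the first '{' to the last '}'. A small helper sketch (a hypothetical extract_json, not in the commit) that captures the same logic in one place:

    import json
    import logging

    logger = logging.getLogger(__name__)

    def extract_json(raw_output: str) -> dict:
        # Mirror the endpoint logic: slice from the first '{' to the last '}' and parse.
        start = raw_output.find('{')
        end = raw_output.rfind('}') + 1
        if start == -1 or end == 0:
            logger.warning("No JSON found in output")
            raise ValueError(f"No JSON object found in model output. Raw output: {raw_output[:500]}")
        return json.loads(raw_output[start:end])

Note that json.loads can still raise JSONDecodeError when the braces enclose malformed JSON, so the surrounding try/except in each endpoint remains useful.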
@@ -458,43 +437,19 @@ Full Deck Length: {len(full_text)} characters
 Produce a FINAL comprehensive review with the same JSON structure as before, consolidating all findings."""

     try:
-        …
-        …
-            {"role": "user", "content": user_message}
-        ]
-        …
-        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        full_prompt = f"{system_message}\n\n{user_message}"
+        messages = [{"role": "user", "content": full_prompt}]

-        …
-        …
-        …
-        max_context = 4096
-        max_input_tokens = 3800
-
-        if prompt_token_count > max_input_tokens:
-            logger.warning(f"Combine prompt is {prompt_token_count} tokens, truncating to {max_input_tokens}")
-            prompt_tokens = prompt_tokens[:, :max_input_tokens]
-            prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
-            prompt_token_count = max_input_tokens
-
-        max_output_tokens = max_context - prompt_token_count - 50
-
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=prompt_token_count).to(model.device)
-
-        output_limit = min(1500, max_output_tokens)
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=output_limit,
+        result = pipe(
+            messages,
+            max_new_tokens=1500,
             temperature=0.3,
             do_sample=True,
             top_p=0.95,
-            …
-            use_cache=True
+            return_full_text=False
         )
-        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

-        …
-        raw_output = raw_output.split("<|assistant|>")[-1]
+        raw_output = result[0]["generated_text"]

         start = raw_output.find('{')
         end = raw_output.rfind('}') + 1
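The old combine path counted prompt tokens and truncated at 3,800 before generating; the pipeline version sends full_prompt unmodified. If oversized decks become a problem again, a rough character-based guard could be reinstated without a tokenizer. A hedged sketch (hypothetical clamp_prompt helper and character budget, not part of the commit):

    import logging

    def clamp_prompt(full_prompt: str, max_chars: int = 3800 * 4) -> str:
        # ~4 characters per token is only a coarse heuristic, not a measured value
        if len(full_prompt) > max_chars:
            logging.getLogger(__name__).warning(
                f"Prompt is {len(full_prompt)} characters, truncating to {max_chars}")
            return full_prompt[:max_chars]
        return full_prompt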
@@ -583,43 +538,19 @@ Return ONLY valid JSON:
 }}"""

     try:
-        …
-        …
-            {"role": "user", "content": user_message}
-        ]
-        …
-        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-        …
-        prompt_tokens = tokenizer.encode(prompt, return_tensors="pt")
-        prompt_token_count = prompt_tokens.shape[1]
-
-        max_context = 4096
-        max_input_tokens = 3800
-
-        if prompt_token_count > max_input_tokens:
-            logger.warning(f"Improvement prompt is {prompt_token_count} tokens, truncating to {max_input_tokens}")
-            prompt_tokens = prompt_tokens[:, :max_input_tokens]
-            prompt = tokenizer.decode(prompt_tokens[0], skip_special_tokens=True)
-            prompt_token_count = max_input_tokens
-
-        max_output_tokens = max_context - prompt_token_count - 50
-
-        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=prompt_token_count).to(model.device)
+        full_prompt = f"{system_message}\n\n{user_message}"
+        messages = [{"role": "user", "content": full_prompt}]

-        …
-        …
-        …
-            max_new_tokens=output_limit,
+        result = pipe(
+            messages,
+            max_new_tokens=1000,
             temperature=0.4,
             do_sample=True,
             top_p=0.95,
-            …
-            use_cache=True
+            return_full_text=False
         )
-        raw_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

-        …
-        raw_output = raw_output.split("<|assistant|>")[-1]
+        raw_output = result[0]["generated_text"]

         start = raw_output.find('{')
         end = raw_output.rfind('}') + 1
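After this change the review, combine, and improvement calls differ only in max_new_tokens and temperature. A consolidation sketch (hypothetical generate_text helper, assuming the module-level pipe from app.py) that would keep those knobs in one place:

    def generate_text(prompt: str, max_new_tokens: int = 1500, temperature: float = 0.3) -> str:
        # Single-turn chat call through the shared pipeline; returns only the new text.
        messages = [{"role": "user", "content": prompt}]
        result = pipe(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            top_p=0.95,
            return_full_text=False,
        )
        return result[0]["generated_text"]

    # the improvement path would then be:
    # raw_output = generate_text(full_prompt, max_new_tokens=1000, temperature=0.4)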
@@ -656,7 +587,7 @@ async def health():
     """Health check endpoint"""
     return {
         "status": "healthy",
-        "model_loaded": …
+        "model_loaded": pipe is not None
     }

 @app.post("/review")
@@ -666,7 +597,7 @@ async def review_deck(file: UploadFile = File(...)):

     Supported formats: PDF, DOCX, PPT, PPTX
     """
-    if …
+    if pipe is None:
         raise HTTPException(status_code=503, detail="Model not loaded yet. Please wait for startup to complete.")

     file_extension = Path(file.filename).suffix.lower()
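Once the pipeline has loaded, the service can be exercised end to end. A client sketch using requests (the port and the /health path are assumptions; the /review route and its 503 behaviour come from app.py):

    # client sketch - adjust BASE to wherever the Space/container is listening
    import requests

    BASE = "http://localhost:7860"

    # health check: reports whether the pipeline finished loading
    print(requests.get(f"{BASE}/health").json())  # e.g. {"status": "healthy", "model_loaded": true}

    # upload a deck for review (PDF, DOCX, PPT, or PPTX)
    with open("deck.pptx", "rb") as f:
        resp = requests.post(f"{BASE}/review", files={"file": ("deck.pptx", f)})
    print(resp.status_code, resp.json())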