from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
import torch
from qwen_vl_utils import process_vision_info
import io
import os

app = FastAPI(title="Qwen OCR API")

# ---------------------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------------------
MODEL_NAME = "NAMAA-Space/Qari-OCR-0.2.2.1-VL-2B-Instruct"  # your fine-tuned model
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"  # base model for LoRA fallback (matches Qwen2VLForConditionalGeneration)

# Detect device and pick a matching dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# ---------------------------------------------------------------------
# MODEL LOADING
# ---------------------------------------------------------------------
print("🚀 Loading model...")
try:
    # Try loading as a full (merged) model first
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print(f"✅ Loaded main model: {MODEL_NAME}")
except Exception as e:
    print(f"⚠️ Direct load failed: {e}")
    print("➡️ Trying as LoRA/PEFT adapter...")
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    model = PeftModel.from_pretrained(base, MODEL_NAME)
    print(f"✅ Loaded LoRA adapter on base model: {BASE_MODEL}")

# Load the processor; fall back to the base model's processor in case the
# repo is an adapter-only checkpoint without processor files
try:
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
except Exception:
    processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)

# device_map already placed the model on the right device; just set eval mode
model.eval()

# ---------------------------------------------------------------------
# OCR ENDPOINT
# ---------------------------------------------------------------------
@app.post("/ocr")
async def ocr(file: UploadFile = File(...)):
    try:
        # Load the uploaded image
        image_bytes = await file.read()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # OCR prompt
        prompt = (
            "Below is an image of one page of a document. "
            "Return its natural text representation accurately, without hallucination."
        )

        # Format the chat message with image and text parts
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]
        }]

        # Prepare model inputs
        text_prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        ).to(device)

        # Generate output
        with torch.no_grad():
            gen_ids = model.generate(**inputs, max_new_tokens=2000)

        # Strip the prompt tokens from the generated sequence, then decode
        trimmed_ids = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)]
        result = processor.batch_decode(
            trimmed_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        return JSONResponse({"text": result})

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)

# ---------------------------------------------------------------------
# MAIN ENTRY (for local debugging)
# ---------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
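
# ---------------------------------------------------------------------
# USAGE EXAMPLE (for reference)
# ---------------------------------------------------------------------
# A minimal client sketch, assuming the server is running locally on the
# default port 7860; "page.png" is a hypothetical sample image:
#
#   curl -X POST "http://localhost:7860/ocr" -F "file=@page.png"
#
# Or from Python, using the `requests` library:
#
#   import requests
#   with open("page.png", "rb") as f:
#       resp = requests.post("http://localhost:7860/ocr", files={"file": f})
#   print(resp.json()["text"])
#
# A successful response is JSON of the form {"text": "..."}; failures
# return {"error": "..."} with HTTP status 500.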