Nick-2x's picture
Update app.py
7344d10 verified
# from fastapi import FastAPI, UploadFile, File
# # from transformers import AutoProcessor, AutoModelForVision2Seq,AutoModel
# from transformers import AutoProcessor,AutoModel
# from PIL import Image
# import torch
# import io
# app = FastAPI()
# MODEL_ID = "zai-org/GLM-OCR"
# print("Loading GLM-OCR model...")
# # processor = AutoProcessor.from_pretrained(MODEL_ID)
# # model = AutoModelForVision2Seq.from_pretrained(
# # MODEL_ID,
# # torch_dtype=torch.float32
# # )
# # processor = AutoProcessor.from_pretrained(
# # MODEL_ID,
# # trust_remote_code=True
# # )
# # model = AutoModelForVision2Seq.from_pretrained(
# # MODEL_ID,
# # trust_remote_code=True,
# # torch_dtype=torch.float32
# # )
# processor = AutoProcessor.from_pretrained(
# MODEL_ID,
# trust_remote_code=True
# )
# model = AutoModel.from_pretrained(
# MODEL_ID,
# trust_remote_code=True
# )
# @app.get("/")
# async def root():
# return {"status": "GLM-OCR API is running"}
# @app.post("/ocr")
# async def extract_text(file: UploadFile = File(...)):
# try:
# contents = await file.read()
# image = Image.open(io.BytesIO(contents)).convert("RGB")
# # inputs = processor(images=image, return_tensors="pt")
# inputs = processor(
# text="Extract all text from the document",
# images=image,
# return_tensors="pt"
# )
# with torch.no_grad():
# outputs = model.generate(**inputs, max_new_tokens=1024)
# text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
# return {
# "success": True,
# "text": text
# }
# except Exception as e:
# return {
# "success": False,
# "error": str(e)
# }
from fastapi import FastAPI, UploadFile, File
from transformers import AutoProcessor, GlmOcrForConditionalGeneration
from PIL import Image
import torch
import io
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# Identifier handed to ``from_pretrained`` for both the processor and the model.
MODEL_ID = "zai-org/GLM-OCR"

print("Loading GLM-OCR model...")

# The processor and model are built once, at import time, and reused by every
# request handler in this module.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Weights are loaded as float32; switch to torch.bfloat16 if you have a GPU.
model = GlmOcrForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,
)
# Put the network in evaluation mode before serving any request.
model = model.eval()
@app.get("/")
async def root():
    """Report that the service is up.

    Returns:
        A dict with a single ``status`` key holding a fixed message.
    """
    status_payload = {"status": "GLM-OCR API is running"}
    return status_payload
@app.post("/ocr")
async def extract_text(file: UploadFile = File(...)):
    """Extract text from an uploaded image with the GLM-OCR model.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        ``{"success": True, "text": <str>}`` when OCR succeeds, otherwise
        ``{"success": False, "error": <str>}``. Failures are reported in the
        response body instead of being raised, so callers always receive the
        same JSON shape.
    """
    try:
        contents = await file.read()

        # An empty body would otherwise surface as an opaque PIL error that
        # mentions an in-memory buffer; report it explicitly instead.
        if not contents:
            return {
                "success": False,
                "error": "Uploaded file is empty"
            }

        # Decode the upload and normalise it to three-channel RGB.
        image = Image.open(io.BytesIO(contents)).convert("RGB")

        # NOTE(review): generation runs synchronously inside this coroutine,
        # so other requests wait until it finishes. Consider offloading it to
        # a worker thread if that becomes a problem.
        return {
            "success": True,
            "text": _run_ocr(image)
        }
    except Exception as e:
        # Deliberate catch-all at the request boundary: every failure is
        # turned into the documented error payload.
        return {
            "success": False,
            "error": str(e)
        }


def _run_ocr(image: "Image.Image") -> str:
    """Run one image through the model and return the stripped decoded text.

    Args:
        image: The RGB image to read.

    Returns:
        The text decoded from the newly generated tokens only, with
        surrounding whitespace removed.
    """
    # 1. Describe the request as a single-turn conversation: image + prompt.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Extract all text from this image."},
            ],
        }
    ]

    # 2. Let the chat template tokenize the prompt and prepare the image so
    #    that generate() receives valid input_ids.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Keep the tensors on the same device as the weights so the endpoint
    # keeps working if the model is later moved to a GPU.
    inputs = inputs.to(model.device)
    # NOTE(review): presumably generate() rejects a "token_type_ids" entry if
    # the processor emits one; dropping it is a no-op when the key is absent.
    # Confirm against the model card's usage example.
    inputs.pop("token_type_ids", None)

    # 3. Greedy decoding, with gradient tracking disabled for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    # 4. generate() returns prompt + continuation; slice off the prompt
    #    tokens so only the model's answer is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    generated_ids = outputs[:, prompt_length:]
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text.strip()