# Superseded first version (AutoModel + direct processor call), kept for reference.
# The active implementation follows after this commented-out block.
# from fastapi import FastAPI, UploadFile, File
# # from transformers import AutoProcessor, AutoModelForVision2Seq,AutoModel
# from transformers import AutoProcessor,AutoModel
# from PIL import Image
# import torch
# import io
# app = FastAPI()
# MODEL_ID = "zai-org/GLM-OCR"
# print("Loading GLM-OCR model...")
# # processor = AutoProcessor.from_pretrained(MODEL_ID)
# # model = AutoModelForVision2Seq.from_pretrained(
# # MODEL_ID,
# # torch_dtype=torch.float32
# # )
# # processor = AutoProcessor.from_pretrained(
# # MODEL_ID,
# # trust_remote_code=True
# # )
# # model = AutoModelForVision2Seq.from_pretrained(
# # MODEL_ID,
# # trust_remote_code=True,
# # torch_dtype=torch.float32
# # )
# processor = AutoProcessor.from_pretrained(
# MODEL_ID,
# trust_remote_code=True
# )
# model = AutoModel.from_pretrained(
# MODEL_ID,
# trust_remote_code=True
# )
# @app.get("/")
# async def root():
# return {"status": "GLM-OCR API is running"}
# @app.post("/ocr")
# async def extract_text(file: UploadFile = File(...)):
# try:
# contents = await file.read()
# image = Image.open(io.BytesIO(contents)).convert("RGB")
# # inputs = processor(images=image, return_tensors="pt")
# inputs = processor(
# text="Extract all text from the document",
# images=image,
# return_tensors="pt"
# )
# with torch.no_grad():
# outputs = model.generate(**inputs, max_new_tokens=1024)
# text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
# return {
# "success": True,
# "text": text
# }
# except Exception as e:
# return {
# "success": False,
# "error": str(e)
# }
import asyncio
import io
import logging
import threading

import torch
from fastapi import FastAPI, UploadFile, File
from PIL import Image
from transformers import AutoProcessor, GlmOcrForConditionalGeneration
# FastAPI application; the route decorators below register on this instance.
app = FastAPI()

# Hugging Face Hub repository the processor and weights are loaded from.
MODEL_ID = "zai-org/GLM-OCR"


def _load_glm_ocr(model_id):
    """Load the GLM-OCR processor and model for *model_id*.

    The model is switched to eval mode before being returned. This runs once at
    import time, so all requests share the same processor and weights.
    """
    print("Loading GLM-OCR model...")
    ocr_processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    ocr_model = GlmOcrForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        # Full precision; torch.bfloat16 is suggested for GPU use
        # (note: the model is not moved to a GPU here).
        torch_dtype=torch.float32,
    )
    ocr_model.eval()
    return ocr_processor, ocr_model


processor, model = _load_glm_ocr(MODEL_ID)
@app.get("/")
async def root():
    """Health-check endpoint confirming the OCR service is up."""
    status_message = "GLM-OCR API is running"
    return {"status": status_message}
# Serializes model.generate so that only one OCR inference runs at a time, as
# before, even though the work now executes on worker threads.
_inference_lock = threading.Lock()
_logger = logging.getLogger(__name__)


def _run_ocr(contents):
    """Run GLM-OCR on raw image bytes and return the recognized text.

    Blocking and compute-heavy: call it from a worker thread, not from the
    event loop.

    Args:
        contents: Encoded image bytes as uploaded (anything PIL can open).

    Returns:
        The decoded model response with surrounding whitespace stripped.

    Raises:
        Whatever PIL, the processor or the model raise (e.g. for bytes that
        are not an image); the endpoint turns these into an error payload.
    """
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Chat-style request: one user turn holding the image plus the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Extract all text from this image."},
            ],
        }
    ]

    # The chat template provides valid input_ids (this fixed an earlier
    # 'NoneType' error). Tensors follow the model's device so this keeps
    # working if the model is ever placed on a GPU.
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # torch.no_grad() is thread-local, so it must be entered on this thread.
    with _inference_lock, torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=False,
        )

    # Drop the echoed prompt tokens and keep only the newly generated answer.
    generated_ids = outputs[:, inputs["input_ids"].shape[1]:]
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return text.strip()


@app.post("/ocr")
async def extract_text(file: UploadFile = File(...)):
    """Run OCR on an uploaded image.

    Returns ``{"success": True, "text": ...}`` on success. If reading,
    decoding or inference fails, returns ``{"success": False, "error": ...}``.
    No explicit status code is set in either case.
    """
    try:
        contents = await file.read()
        # Inference is synchronous; running it on a worker thread keeps the
        # event loop free to serve other requests (e.g. the "/" health check).
        text = await asyncio.to_thread(_run_ocr, contents)
        return {
            "success": True,
            "text": text,
        }
    except Exception as e:
        # Keep the existing error payload for clients, but record the traceback
        # instead of silently discarding it.
        _logger.exception("OCR request failed")
        return {
            "success": False,
            "error": str(e),
        }