Arghya Ghosh committed on
Commit
06cff9f
·
verified ·
1 Parent(s): 0fd8e2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -4
app.py CHANGED
@@ -1,7 +1,86 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import torch
3
+ from PIL import Image
4
+ from threading import Thread
5
+ from fastapi import FastAPI, File, UploadFile, Form
6
+ from fastapi.responses import JSONResponse
7
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
8
+ from transformers.generation.streamers import TextIteratorStreamer
9
 
10
  app = FastAPI()
11
 
12
+ # Setup device
13
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
+
15
+ # Load model and processor
16
+ MODEL_ID = "nanonets/Nanonets-OCR-s"
17
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
18
+ model = (
19
+ Qwen2_5_VLForConditionalGeneration.from_pretrained(
20
+ MODEL_ID,
21
+ trust_remote_code=True,
22
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
+ )
24
+ .to(device)
25
+ .eval()
26
+ )
27
+
28
+
29
+ def generate_response(image, prompt_text, **kwargs):
30
+ images = [image]
31
+
32
+ messages = [
33
+ {
34
+ "role": "user",
35
+ "content": [{"type": "image"} for _ in images]
36
+ + [{"type": "text", "text": prompt_text}],
37
+ }
38
+ ]
39
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
40
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
41
+
42
+ streamer = TextIteratorStreamer(
43
+ processor, skip_prompt=True, skip_special_tokens=True
44
+ )
45
+ generation_kwargs = {
46
+ **inputs,
47
+ "streamer": streamer,
48
+ "max_new_tokens": kwargs.get("max_new_tokens", 1024),
49
+ "temperature": kwargs.get("temperature", 0.6),
50
+ "top_p": kwargs.get("top_p", 0.9),
51
+ "top_k": kwargs.get("top_k", 50),
52
+ "repetition_penalty": kwargs.get("repetition_penalty", 1.2),
53
+ }
54
+
55
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
56
+ thread.start()
57
+
58
+ output = ""
59
+ for chunk in streamer:
60
+ output += chunk.replace("<|im_end|>", "")
61
+ return output.strip()
62
+
63
+
64
+ @app.post("/ocr/image")
65
+ async def ocr_image(
66
+ prompt: str = Form(...),
67
+ image: UploadFile = File(...),
68
+ max_new_tokens: int = Form(1024),
69
+ temperature: float = Form(0.6),
70
+ top_p: float = Form(0.9),
71
+ top_k: int = Form(50),
72
+ repetition_penalty: float = Form(1.2),
73
+ ):
74
+ image_bytes = await image.read()
75
+ pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
76
+
77
+ result = generate_response(
78
+ image=pil_image,
79
+ prompt_text=prompt,
80
+ max_new_tokens=max_new_tokens,
81
+ temperature=temperature,
82
+ top_p=top_p,
83
+ top_k=top_k,
84
+ repetition_penalty=repetition_penalty,
85
+ )
86
+ return JSONResponse(content={"result": result})