Arghya Ghosh committed on
Commit
ec9c6ea
·
verified ·
1 Parent(s): 3e6422b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -33
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import io
2
  import torch
3
  from PIL import Image
4
- from threading import Thread
5
  from fastapi import FastAPI, File, UploadFile, Form
6
  from fastapi.responses import JSONResponse
7
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
@@ -9,39 +8,31 @@ from transformers.generation.streamers import TextIteratorStreamer
9
 
10
  app = FastAPI()
11
 
12
- # Setup device
13
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
 
15
- # Load model and processor
16
  MODEL_ID = "nanonets/Nanonets-OCR-s"
17
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
18
- model = (
19
- Qwen2_5_VLForConditionalGeneration.from_pretrained(
20
- MODEL_ID,
21
- trust_remote_code=True,
22
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
- )
24
- .to(device)
25
- .eval()
26
  )
 
27
 
28
 
29
- def generate_response(image, prompt_text, **kwargs):
30
- images = [image]
31
-
32
  messages = [
33
  {
34
  "role": "user",
35
- "content": [{"type": "image"} for _ in images]
36
- + [{"type": "text", "text": prompt_text}],
37
  }
38
  ]
39
  prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
40
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
41
 
42
- streamer = TextIteratorStreamer(
43
- processor, skip_prompt=True, skip_special_tokens=True
44
- )
45
  generation_kwargs = {
46
  **inputs,
47
  "streamer": streamer,
@@ -52,8 +43,8 @@ def generate_response(image, prompt_text, **kwargs):
52
  "repetition_penalty": kwargs.get("repetition_penalty", 1.2),
53
  }
54
 
55
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
56
- thread.start()
57
 
58
  output = ""
59
  for chunk in streamer:
@@ -61,6 +52,11 @@ def generate_response(image, prompt_text, **kwargs):
61
  return output.strip()
62
 
63
 
 
 
 
 
 
64
  @app.post("/ocr/image")
65
  async def ocr_image(
66
  prompt: str = Form(...),
@@ -74,13 +70,16 @@ async def ocr_image(
74
  image_bytes = await image.read()
75
  pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
76
 
77
- result = generate_response(
78
- image=pil_image,
79
- prompt_text=prompt,
80
- max_new_tokens=max_new_tokens,
81
- temperature=temperature,
82
- top_p=top_p,
83
- top_k=top_k,
84
- repetition_penalty=repetition_penalty,
85
- )
86
- return JSONResponse(content={"result": result})
 
 
 
 
1
import io

import torch
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# Required by generate_response(); this import is not visible in the rendered
# final file even though the streamer is still used — restored explicitly.
from transformers.generation.streamers import TextIteratorStreamer

app = FastAPI()

# Device config: prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load processor and model once at module import so all requests share them.
MODEL_ID = "nanonets/Nanonets-OCR-s"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    # fp16 halves GPU memory; CPU inference stays in fp32 for kernel support.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model = model.to(device).eval()
23
 
24
 
25
def generate_response(image: Image.Image, prompt_text: str, **kwargs) -> str:
    """Run the OCR model on *image* guided by *prompt_text* and return the text.

    Keyword args (all optional — the /ocr/image endpoint always supplies them):
        max_new_tokens, temperature, top_p, top_k, repetition_penalty.

    Returns:
        The generated text with surrounding whitespace stripped.
    """
    # Local import: the module-level Thread import was removed in this commit.
    from threading import Thread

    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(device)

    # NOTE(review): TextIteratorStreamer is documented to take a tokenizer;
    # passing the processor relies on it forwarding decode() — confirm.
    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        # NOTE(review): the fallback defaults below are assumptions — the diff
        # view hid the original lines. The only caller passes explicit values
        # for all five, so these fallbacks are never hit in practice.
        "max_new_tokens": kwargs.get("max_new_tokens", 1024),
        "temperature": kwargs.get("temperature", 0.7),
        "top_p": kwargs.get("top_p", 0.9),
        "top_k": kwargs.get("top_k", 50),
        "repetition_penalty": kwargs.get("repetition_penalty", 1.2),
    }

    # Run generation on a worker thread so the streamer is drained while tokens
    # are produced — the documented TextIteratorStreamer pattern. A direct
    # model.generate() call only works by relying on the streamer's internal
    # queue being unbounded.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    output = ""
    for chunk in streamer:
        output += chunk
    worker.join()  # make sure generation has fully finished before returning
    return output.strip()
53
 
54
 
55
@app.get("/")
def health():
    """Liveness probe: report that the service is up and which model it serves."""
    payload = {"status": "running", "model": MODEL_ID}
    return payload
58
+
59
+
60
@app.post("/ocr/image")
async def ocr_image(
    prompt: str = Form(...),
    # NOTE(review): the defaults of the five parameters below were hidden in
    # the diff view this file was reconstructed from — confirm against the
    # deployed app before relying on them.
    max_new_tokens: int = Form(1024),
    temperature: float = Form(0.7),
    top_p: float = Form(0.9),
    top_k: int = Form(50),
    repetition_penalty: float = Form(1.2),
    image: UploadFile = File(...),
):
    """OCR endpoint: run the model over one uploaded image.

    Form fields: a required prompt, optional sampling parameters, and the
    image file. Returns {"result": <text>} on success, or HTTP 500 with
    {"error": <message>} on failure.
    """
    image_bytes = await image.read()
    pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    try:
        result = generate_response(
            image=pil_image,
            prompt_text=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        return JSONResponse(content={"result": result})
    except Exception as e:
        # Top-level boundary: report the failure as JSON rather than a bare 500.
        # NOTE(review): str(e) may leak internal details to clients — consider
        # logging the traceback server-side and returning a generic message.
        return JSONResponse(status_code=500, content={"error": str(e)})