Arghya Ghosh committed on
Commit
c00f569
·
verified ·
1 Parent(s): bf54927

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +37 -83
main.py CHANGED
@@ -1,94 +1,48 @@
1
import io

import torch
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers.generation.streamers import TextIteratorStreamer

app = FastAPI()

# Open CORS so browser clients on any origin can reach the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prefer GPU when present; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the OCR processor and model once at import time so every request
# reuses the same weights.
MODEL_ID = "nanonets/Nanonets-OCR-s"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# fp16 is only safe on GPU; CPU inference stays in fp32.
_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=_dtype,
)
model = model.to(device).eval()
32
 
 
 
 
33
 
34
def generate_response(image: Image.Image, prompt_text: str, **kwargs) -> str:
    """Run the OCR model on *image* guided by *prompt_text* and return the text.

    Optional keyword args (forwarded to ``model.generate``):
    max_new_tokens (1024), temperature (0.6), top_p (0.9), top_k (50),
    repetition_penalty (1.2).
    """
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
        }
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=kwargs.get("max_new_tokens", 1024),
            # do_sample must be enabled, otherwise temperature/top_p/top_k
            # are silently ignored by transformers (greedy decoding).
            do_sample=True,
            temperature=kwargs.get("temperature", 0.6),
            top_p=kwargs.get("top_p", 0.9),
            top_k=kwargs.get("top_k", 50),
            repetition_penalty=kwargs.get("repetition_penalty", 1.2),
        )

    # Decode only the newly generated tokens (drop the echoed prompt).
    # This replaces the previous TextIteratorStreamer misuse: the streamer
    # was iterated only after a blocking generate() whose return value was
    # discarded — decoding the returned ids directly is the intended path.
    prompt_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[:, prompt_len:]
    text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
    return text.replace("<|im_end|>", "").strip()
62
-
63
-
64
@app.get("/")
def health():
    """Liveness probe: confirm the service is up and report the loaded model."""
    payload = {"status": "running", "model": MODEL_ID}
    return payload
67
-
68
-
69
@app.post("/ocr/image")
async def ocr_image(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    max_new_tokens: int = Form(1024),
    temperature: float = Form(0.6),
    top_p: float = Form(0.9),
    top_k: int = Form(50),
    repetition_penalty: float = Form(1.2),
):
    """Run OCR on an uploaded image with the given prompt.

    Returns ``{"result": ...}`` on success, or a 500 JSON body
    ``{"error": ...}`` on any failure — including an unreadable upload.
    """
    try:
        # Reading/decoding the upload belongs inside the try: a corrupt
        # image should also produce the structured 500 response rather
        # than an unhandled exception.
        image_bytes = await image.read()
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        result = generate_response(
            image=pil_image,
            prompt_text=prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        return JSONResponse(content={"result": result})
    except Exception as e:
        # Boundary handler: surface the failure to the client as JSON.
        return JSONResponse(status_code=500, content={"error": str(e)})
 
1
+ from fastapi import FastAPI, Query
 
 
 
 
 
2
  from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
3
+ from qwen_vl_utils import process_vision_info
4
+ import torch
5
 
6
  app = FastAPI()
7
 
8
+ checkpoint = "Qwen/Qwen2.5-VL-3B-Instruct"
9
+ min_pixels = 256*28*28
10
+ max_pixels = 1280*28*28
11
+ processor = AutoProcessor.from_pretrained(
12
+ checkpoint,
13
+ min_pixels=min_pixels,
14
+ max_pixels=max_pixels
15
  )
 
 
 
 
 
 
 
16
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
17
+ checkpoint,
18
+ torch_dtype=torch.bfloat16,
19
+ device_map="auto",
20
+ # attn_implementation="flash_attention_2",
21
  )
 
22
 
23
@app.get("/")
def read_root():
    """Root health check that points callers at the /predict endpoint."""
    msg = "API is live. Use the /predict endpoint."
    return {"message": msg}
26
 
27
@app.get("/predict")
def predict(image_url: str = Query(...), prompt: str = Query(...)):
    """Run the vision-language model on an image URL plus a text prompt.

    Returns ``{"response": <generated text>}``.
    """
    chat = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_url},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    # Render the chat template, then gather the image/video inputs it references.
    rendered = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    images, videos = process_vision_info(chat)
    batch = processor(
        text=[rendered],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        out_ids = model.generate(**batch, max_new_tokens=128)

    # Drop the echoed prompt tokens so only newly generated text is decoded.
    trimmed = [seq[len(src):] for src, seq in zip(batch.input_ids, out_ids)]
    decoded = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": decoded[0]}