Leon4gr45 commited on
Commit
2a7a917
·
verified ·
1 Parent(s): 6947e38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -9
app.py CHANGED
@@ -6,21 +6,23 @@ from transformers import AutoTokenizer, AutoModel
6
  from PIL import Image
7
  from typing import Optional
8
  import io
9
- import torch
10
 
11
  app = FastAPI()
12
 
13
- device = "cpu"
 
14
  model_id = "OpenGVLab/InternVL2_5-2B"
15
 
16
- # 2. Update the loading line to force bfloat16 and use low_cpu_mem_usage
17
  model = AutoModel.from_pretrained(
18
  model_id,
19
  trust_remote_code=True,
20
- torch_dtype=torch.bfloat16,
21
  low_cpu_mem_usage=True,
22
- # revision="65b9340"
23
- ).to(device).eval()
 
 
24
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
25
 
26
  class GenerateRequest(BaseModel):
@@ -30,7 +32,6 @@ class GenerateRequest(BaseModel):
30
 
31
  @app.post("/generate")
32
  async def generate(image: UploadFile = File(...), request: GenerateRequest = Depends()):
33
- # Read image
34
  image_bytes = await image.read()
35
  pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
36
 
@@ -39,7 +40,9 @@ async def generate(image: UploadFile = File(...), request: GenerateRequest = Dep
39
  else:
40
  prompt = f"<s><image>\nDescribe the image.</s>"
41
 
42
- inputs = tokenizer(prompt, pil_image, return_tensors="pt").to(device)
 
 
43
 
44
  generation_args = {
45
  "max_new_tokens": request.max_new_tokens,
@@ -54,4 +57,4 @@ async def generate(image: UploadFile = File(...), request: GenerateRequest = Dep
54
 
55
  @app.get("/")
56
  async def read_root():
57
- return {"message": "InternVL2_5-2B API. Go to /docs for API documentation."}
 
6
from PIL import Image
from typing import Optional
import io
# FIX(review): this commit deleted `import torch` while still using
# `torch.bfloat16` below, which raises NameError at import time.
# `torch` must stay imported.
import torch

app = FastAPI()

# The old `device = "cpu"` variable is no longer needed:
# accelerate places the weights automatically via device_map="auto".
model_id = "OpenGVLab/InternVL2_5-2B"

# Load the model with smart memory management.
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,   # half the footprint of float32 weights
    low_cpu_mem_usage=True,
    device_map="auto",            # accelerate handles placement; prevents OOM
    offload_folder="offload",     # spill weights to disk if RAM fills up
).eval()                          # no .to(device): device_map already placed it

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
27
 
28
  class GenerateRequest(BaseModel):
 
32
 
33
  @app.post("/generate")
34
  async def generate(image: UploadFile = File(...), request: GenerateRequest = Depends()):
 
35
  image_bytes = await image.read()
36
  pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
37
 
 
40
  else:
41
  prompt = f"<s><image>\nDescribe the image.</s>"
42
 
43
+ # 2. Update inputs to use the model's device dynamically
44
+ # model.device will tell us where the model (or its first layer) lives
45
+ inputs = tokenizer(prompt, pil_image, return_tensors="pt").to(model.device)
46
 
47
  generation_args = {
48
  "max_new_tokens": request.max_new_tokens,
 
57
 
58
@app.get("/")
async def read_root():
    """Root endpoint: directs clients to the interactive /docs page."""
    welcome = {"message": "InternVL2_5-2B API. Go to /docs for API documentation."}
    return welcome