marveljo committed on
Commit
d6219de
·
verified ·
1 Parent(s): a8d29c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -12
app.py CHANGED
@@ -1,17 +1,56 @@
1
- from fastapi import FastAPI, UploadFile
2
- from transformers import pipeline
 
 
 
3
  from PIL import Image
4
- import io
5
- import uvicorn
6
 
7
- app = FastAPI()
8
- pipe = pipeline("image-classification", model="google/vit-base-patch16-224")
 
 
 
 
 
 
 
9
 
 
 
 
10
  @app.post("/predict")
11
- async def predict(file: UploadFile):
12
- img = Image.open(io.BytesIO(await file.read()))
13
- preds = pipe(img)
14
- return preds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- if __name__ == "__main__":
17
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import torch
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq
3
+ from fastapi import FastAPI, File, UploadFile, Form
4
+ from fastapi.responses import JSONResponse
5
+ from io import BytesIO
6
  from PIL import Image
 
 
7
 
8
# --- Application object and model/processor loading ---
# Everything below runs once, at import time, so the first request does not
# pay the checkpoint-loading cost.
app = FastAPI(title="Aloe Vision 7B AR API")

# Hub repository of the checkpoint served by this API.
model_id = "HPAI-BSC/Aloe-Vision-7B-AR"

# NOTE: trust_remote_code=True allows transformers to run Python code shipped
# inside the model repository; keep it only for repositories you trust.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    trust_remote_code=True,
    device_map="auto",  # let the loader decide weight placement
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
19
+
20
+ # --- Inference endpoint ---
21
  @app.post("/predict")
22
async def predict(
    file: UploadFile = File(...),
    question: str = Form("What do you see?"),
):
    """Answer a free-text question about an uploaded image.

    Args:
        file: Image uploaded as the multipart form field ``file``. It is
            decoded with Pillow and converted to RGB.
        question: Prompt sent to the model together with the image.

    Returns:
        A JSON response ``{"answer": <text>}`` on success. On failure the
        body is ``{"error": <message>}`` with status 400 when the upload
        cannot be decoded as an image, and status 500 for any other error.
    """
    try:
        raw_bytes = await file.read()

        try:
            image = Image.open(BytesIO(raw_bytes)).convert("RGB")
        except OSError as e:
            # Pillow reports undecodable or truncated data as OSError
            # (UnidentifiedImageError is a subclass): that is a client
            # error, not a server fault.
            return JSONResponse(
                {"error": f"Invalid image: {e}"}, status_code=400
            )

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]

        # Render the chat prompt as a string; tokenisation happens in the
        # processor call below, together with the image.
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Hand the decoded image straight to the processor.
        # NOTE(review): a `process_vision_info` helper exists in the separate
        # `qwen_vl_utils` package, not as a processor method as far as can be
        # confirmed - verify against this checkpoint's processor class.
        inputs = processor(
            text=[text], images=[image], return_tensors="pt"
        ).to(model.device)

        generated = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=processor.tokenizer.eos_token_id,
        )

        # Drop the prompt by token position rather than by string matching:
        # `text` still contains special tokens that decoding with
        # skip_special_tokens=True removes, so a string split never matches.
        # Assumes generate() returns the prompt tokens followed by the new
        # tokens (decoder-only behaviour) - TODO confirm for this model.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = generated[:, prompt_length:]
        answer = processor.batch_decode(
            new_tokens, skip_special_tokens=True
        )[0].strip()

        return JSONResponse({"answer": answer})

    except Exception as e:
        # Top-level boundary of the endpoint: report any unexpected failure
        # as a JSON 500 instead of letting it propagate.
        return JSONResponse({"error": str(e)}, status_code=500)