longjava2024 commited on
Commit
fd7ebeb
·
verified ·
1 Parent(s): ce6cc1f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

app = FastAPI()

# Hugging Face model id of the checkpoint served by this app.
# NOTE(review): this repo ships custom model code; loading may require
# trust_remote_code=True — confirm against the model card before deploying.
MODEL_NAME = "5CD-AI/Vintern-1B-v2"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# BUG FIX: the original passed load_in_4bit=True together with
# device_map="cpu". bitsandbytes 4-bit quantization requires a CUDA
# device, so the load fails on a CPU-only host. Enable INT4 only when
# CUDA is present; on CPU fall back to float32 (float16 on CPU is slow
# and several kernels are unsupported).
_use_cuda = torch.cuda.is_available()
print(f"Loading model ({'INT4, CUDA' if _use_cuda else 'FP32, CPU'})...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    load_in_4bit=_use_cuda,
    device_map="auto" if _use_cuda else "cpu",
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
)
21
class InferRequest(BaseModel):
    # Request body for POST /infer: the prompt text to generate from.
    # The /infer handler truncates it to 512 tokens before generation.
    text: str
24
@app.post("/infer")
def infer(req: InferRequest):
    """Tokenize the request text, run greedy generation, and return the
    decoded output as {"result": <text>}."""
    # Encode the prompt, capping it at 512 tokens.
    encoded = tokenizer(
        req.text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Greedy decoding (do_sample=False), at most 256 newly generated tokens.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            do_sample=False,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"result": decoded}