File size: 2,055 Bytes
556b039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import asyncio
import os, time, uuid, json
import threading

import torch
from unsloth import FastModel
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

# MODEL is the identifier passed to the loader; SERVED is the model id this
# server reports back to API clients.
MODEL = "iryahayri/lila-mirror-v1-merged"
SERVED = "lila-v1"

print("Loading Lila...")
# Load in bfloat16 with 4-bit loading disabled and max_seq_length of 4096.
_load_options = {
    "model_name": MODEL,
    "max_seq_length": 4096,
    "load_in_4bit": False,
    "dtype": torch.bfloat16,
}
model, tokenizer = FastModel.from_pretrained(**_load_options)
model.eval()
# If the loader returned an object exposing a `.tokenizer` attribute
# (presumably a processor wrapping the text tokenizer -- confirm), use the
# inner tokenizer; otherwise use the returned object itself.
tok = getattr(tokenizer, "tokenizer", tokenizer)
print("Model loaded.")

# Application object. CORS is fully open: every origin, method and header is
# allowed.
app = FastAPI()
_cors_policy = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)

@app.get("/v1/models")
async def models():
    """Return an OpenAI-style model list containing the single served model."""
    served_entry = {"id": SERVED, "object": "model", "owned_by": "hayri"}
    return {"object": "list", "data": [served_entry]}

# Generation runs in worker threads (see `chat`). The single shared
# model/tokenizer pair is not assumed to be safe for overlapping calls, so
# every generation holds this lock from tokenization through decoding.
_GENERATION_LOCK = threading.Lock()


def _bad_request(message):
    """Return an OpenAI-style HTTP 400 error response carrying *message*."""
    return JSONResponse(
        status_code=400,
        content={"error": {"message": message, "type": "invalid_request_error"}},
    )


def _read_number(body, key, default, cast):
    """Read a numeric request field, using *default* when it is missing or null.

    Args:
        body: Parsed JSON request object.
        key: Name of the field to read.
        default: Value returned when the field is absent or JSON ``null``.
        cast: ``int`` or ``float``; applied to a present value.

    Raises:
        ValueError: If the field is present but cannot be converted.
    """
    raw = body.get(key)
    if raw is None:
        return default
    # bool is a subclass of int; reject it rather than reading true as 1.
    if isinstance(raw, bool):
        raise ValueError(f"'{key}' must be a number.")
    try:
        return cast(raw)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"'{key}' must be a number.") from err


def _generate(msgs, max_new_tokens, temperature, top_p, repetition_penalty):
    """Run one blocking generation for *msgs*.

    Returns:
        Tuple ``(text, prompt_tokens, completion_tokens, finish_reason)`` where
        ``finish_reason`` is ``"length"`` when the token budget was exhausted
        without ending on an end-of-sequence token, else ``"stop"``.
    """
    with _GENERATION_LOCK:
        prompt = tok.apply_chat_template(
            msgs,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        # Follow the model's own device instead of hard-coding "cuda".
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs.input_ids.shape[1]

        # temperature == 0 is the conventional request for deterministic
        # output; sampling with a non-positive temperature is rejected by
        # transformers, so switch to greedy decoding in that case.
        sampling = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "do_sample": sampling,
            "pad_token_id": tok.eos_token_id,
        }
        if sampling:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p

        with torch.no_grad():
            out = model.generate(**inputs, **gen_kwargs)

        # Keep only the tokens produced after the prompt.
        new_tokens = out[0][prompt_len:]
        text = tok.decode(new_tokens, skip_special_tokens=True)
        completion_len = int(new_tokens.shape[0])

        # Collect every id that counts as end-of-sequence so a reply that ends
        # exactly on the budget with EOS is still reported as "stop".
        stop_ids = {tok.eos_token_id}
        cfg_eos = getattr(getattr(model, "generation_config", None), "eos_token_id", None)
        if isinstance(cfg_eos, int):
            stop_ids.add(cfg_eos)
        elif cfg_eos:
            stop_ids.update(cfg_eos)

    ended_on_eos = completion_len > 0 and int(new_tokens[-1]) in stop_ids
    hit_budget = completion_len >= max_new_tokens
    finish_reason = "length" if hit_budget and not ended_on_eos else "stop"
    return text, prompt_len, completion_len, finish_reason


@app.post("/v1/chat/completions")
async def chat(req: Request):
    """Handle an OpenAI-style chat completion request.

    Reads ``messages`` plus the optional ``max_tokens`` (default 200),
    ``temperature`` (0.85), ``top_p`` (0.9) and ``repetition_penalty`` (1.15)
    from the JSON body; a missing or ``null`` field falls back to its default.
    Any other field (including ``stream``) is ignored, and the whole
    completion is returned in a single JSON response.

    Returns:
        A ``chat.completion`` JSON response with one choice and a ``usage``
        block, or an HTTP 400 error response when the body is not valid JSON,
        is not an object, has no non-empty ``messages`` list, or carries an
        out-of-range sampling parameter.
    """
    try:
        body = await req.json()
    except ValueError:  # json.JSONDecodeError and UnicodeDecodeError both derive from it
        return _bad_request("Request body must be valid JSON.")
    if not isinstance(body, dict):
        return _bad_request("Request body must be a JSON object.")

    msgs = body.get("messages")
    if not isinstance(msgs, list) or not msgs:
        return _bad_request("'messages' must be a non-empty list.")

    try:
        max_new_tokens = _read_number(body, "max_tokens", 200, int)
        temperature = _read_number(body, "temperature", 0.85, float)
        top_p = _read_number(body, "top_p", 0.9, float)
        repetition_penalty = _read_number(body, "repetition_penalty", 1.15, float)
    except ValueError as err:
        return _bad_request(str(err))

    if max_new_tokens < 1:
        return _bad_request("'max_tokens' must be at least 1.")
    if temperature < 0:
        return _bad_request("'temperature' must not be negative.")
    if not 0 <= top_p <= 1:
        return _bad_request("'top_p' must be between 0 and 1.")
    if repetition_penalty <= 0:
        return _bad_request("'repetition_penalty' must be positive.")

    # model.generate() blocks for the whole generation; run it in the default
    # thread pool so the event loop keeps serving other requests meanwhile.
    loop = asyncio.get_running_loop()
    text, prompt_len, completion_len, finish_reason = await loop.run_in_executor(
        None, _generate, msgs, max_new_tokens, temperature, top_p, repetition_penalty
    )

    return JSONResponse({
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": SERVED,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_len,
            "completion_tokens": completion_len,
            "total_tokens": prompt_len + completion_len,
        },
    })

if __name__ == "__main__":
    import uvicorn

    # Script entry point: serve `app` on all interfaces (0.0.0.0), port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)