File size: 1,788 Bytes
fca06e0
 
7017ee7
fca06e0
6266b63
c65856a
7017ee7
fca06e0
 
 
 
7017ee7
 
fca06e0
50cfe8b
fca06e0
 
 
 
0074895
c65856a
7017ee7
6266b63
 
 
 
 
 
 
 
 
 
fca06e0
 
89a3d56
fca06e0
 
 
 
 
7017ee7
89a3d56
548d89d
89a3d56
 
7017ee7
89a3d56
 
 
 
f564f0f
89a3d56
a34d277
 
89a3d56
 
 
fca06e0
7017ee7
e20dee4
1ff1004
89a3d56
fca06e0
 
1ff1004
09b64bd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fastapi.middleware.cors import CORSMiddleware


# -------------------------------
# Load model & tokenizer from HF Hub
# -------------------------------
model_name = "thedeba/debai"  # HF Hub model path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
# Spaces free tier uses CPU; set this to "cuda" if a GPU is granted.
device = "cpu"
# Keep the weights on the same device as the inputs. The /generate handler
# moves its input tensors to `device`, so the model must live there too;
# otherwise changing `device` above would cause a device-mismatch error.
model.to(device)

# -------------------------------
# FastAPI setup
# -------------------------------
app = FastAPI()


# Let browser clients on any origin (e.g. a GitHub Pages front end) call the
# API. Credentials are disabled on purpose: the endpoints in this file read no
# cookies or auth headers. Also, combining allow_origins=["*"] with
# allow_credentials=True makes Starlette echo back the caller's Origin, which
# would let any website issue credentialed cross-origin requests.
# To restrict access, replace "*" with explicit origins,
# e.g. ["https://<username>.github.io"].
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)



class Query(BaseModel):
    """JSON request body accepted by the POST /generate endpoint."""

    # The user's message; it is sent to the model as a single "user" chat turn.
    text: str

@app.post("/generate")
def generate(query: Query):
    """Generate a chat reply to ``query.text`` with the loaded model.

    The text is wrapped as a single-turn user message and rendered with the
    tokenizer's chat template. It is then passed to ``model.generate`` with
    sampling (temperature 0.5, min_p 0.1, at most 200 new tokens).

    Args:
        query: Request body; ``query.text`` is the user's message.

    Returns:
        ``{"response": <text>}``, where ``<text>`` is only the newly generated
        continuation. It is decoded with special tokens removed and
        surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": query.text}]

    # Render the conversation with the chat template and append the assistant
    # prompt so generation starts directly with the reply.
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)
    # One unpadded sequence: every position is a real token. Passing the mask
    # explicitly avoids generate() having to guess it, which is unreliable
    # when the tokenizer's pad token is the same as its EOS token.
    attention_mask = torch.ones_like(input_ids)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=200,
        use_cache=True,
        temperature=0.5,
        do_sample=True,
        min_p=0.1,
    )

    # A causal LM's generate() output starts with the prompt tokens. Drop them
    # by token count rather than splitting the decoded text on "assistant".
    # That split truncated any reply that itself contained the word
    # "assistant".
    prompt_length = input_ids.shape[-1]
    new_tokens = outputs[0, prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return {"response": response}

@app.get("/")
def root():
    """Liveness check: return a fixed payload showing the API is up."""
    status = dict(debai="API is running!")
    return status

if __name__ == "__main__":
    # Serve the app directly when this file is run as a script, listening on
    # all interfaces. Port 7860 is the default app port for Hugging Face Spaces.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)