File size: 1,788 Bytes
fca06e0 7017ee7 fca06e0 6266b63 c65856a 7017ee7 fca06e0 7017ee7 fca06e0 50cfe8b fca06e0 0074895 c65856a 7017ee7 6266b63 fca06e0 89a3d56 fca06e0 7017ee7 89a3d56 548d89d 89a3d56 7017ee7 89a3d56 f564f0f 89a3d56 a34d277 89a3d56 fca06e0 7017ee7 e20dee4 1ff1004 89a3d56 fca06e0 1ff1004 09b64bd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 | from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from fastapi.middleware.cors import CORSMiddleware
# -------------------------------
# Load model & tokenizer from HF Hub
# -------------------------------
model_name = "thedeba/debai"  # HF Hub model path
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Use a GPU when one is available (e.g. an upgraded Space). The free tier is
# CPU-only, where this resolves to "cpu" exactly as before. The model must live
# on the same device that /generate moves its inputs to; selecting the device
# and moving the model together keeps the two from drifting apart (previously
# setting device="cuda" without un-commenting model.to() would fail at request
# time with a device-mismatch error).
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# -------------------------------
# FastAPI setup
# -------------------------------
app = FastAPI()

# CORS: let a browser front end (e.g. a GitHub Pages site) call this API.
# To lock it down, replace "*" in allow_origins with the front end's origin,
# e.g. ["https://<username>.github.io"].
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the Starlette CORS docs. Nothing visible in this file uses
# cookies or auth, so credentials can likely be disabled; confirm the front end
# does not send credentialed requests before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class Query(BaseModel):
    """JSON request body accepted by ``POST /generate``."""

    text: str  # the user's message; sent to the model as a single chat turn
@app.post("/generate")
def generate(query: Query):
    """Generate a chat reply to ``query.text`` with the loaded causal LM.

    The user text is wrapped as a single-turn chat, rendered with the
    tokenizer's chat template (including the generation prompt), and sampled
    from the model (temperature 0.5, min-p 0.1, up to 200 new tokens).

    Args:
        query: Request body; ``query.text`` is the user's message.

    Returns:
        ``{"response": reply}``, where ``reply`` is the decoded text of the
        newly generated tokens only, with special tokens removed and
        surrounding whitespace stripped.
    """
    messages = [{"role": "user", "content": query.text}]

    # Render the conversation through the model's chat template.
    # NOTE(review): assumes apply_chat_template returns a (1, seq_len) tensor of
    # input ids for tokenize=True / return_tensors="pt", as the original code did.
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)
    prompt_len = input_ids.shape[-1]

    outputs = model.generate(
        input_ids=input_ids,
        # Single, unpadded sequence: every position is a real token. Passing
        # the mask explicitly avoids generate() guessing it from the pad token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=200,
        use_cache=True,
        temperature=0.5,
        do_sample=True,
        min_p=0.1,
    )

    # For decoder-only models generate() returns prompt + continuation; decode
    # only the continuation. The previous approach split the full decoded text
    # on "assistant", which truncated any reply containing that word and
    # returned the whole transcript for templates naming the role differently.
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return {"response": response}
@app.get("/")
def root():
    """Liveness endpoint: report that the API process is serving requests."""
    status = {"debai": "API is running!"}
    return status
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
|