File size: 1,971 Bytes
import os

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# -------------------------------
# Load model & tokenizer from HF Hub
# -------------------------------
# NOTE: an alternative path (LoRA adapter applied with PeftModel.from_pretrained
# on top of "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit") used to be left here
# as commented-out code; the `peft` import is kept for that path.
model_name = "thedeba/Friday"

# Module-level name kept for compatibility. Weight placement itself is decided
# by device_map="auto" below, not by this value.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# device_map="auto" lets accelerate place the weights at load time (GPU when one
# is available, otherwise CPU). Do NOT follow this with model.to(...):
# - it is redundant;
# - transformers raises on .to() for bitsandbytes 4-bit models;
# - accelerate refuses to move models with offloaded modules.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)
# -------------------------------
# FastAPI setup
# -------------------------------
app = FastAPI()

# Wide-open CORS policy: every origin, HTTP method and request header is
# accepted, and credentialed cross-origin requests are allowed.
# NOTE(review): the CORS spec does not permit a literal "*" origin on
# credentialed responses. Confirm that allow_credentials=True is really needed:
# as far as this file shows, the endpoints read no cookies or auth headers.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
class Query(BaseModel):
    """JSON request body accepted by ``POST /generate``.

    Attributes:
        text: The user's message; it is sent to the model as a single
            user turn of the chat.
    """

    text: str
@app.post("/generate")
def generate(query: Query):
    """Generate the assistant's reply to ``query.text``.

    Steps:
    1. Wrap the text as a single user message.
    2. Render it with the tokenizer's chat template, including the assistant
       generation prompt.
    3. Sample a continuation with ``model.generate``.
    4. Decode only the newly generated tokens, so the prompt never leaks into
       the reply.

    Returns:
        ``{"response": <generated text>}``
    """
    messages = [{"role": "user", "content": query.text}]

    # Render the chat with the model's template and append the assistant
    # header so the model continues speaking as the assistant.
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)  # device_map decides where the model's inputs must live
    # Single unpadded sequence, so every position is a real token. Passing the
    # mask explicitly stops generate() from inferring it from pad/eos ids.
    attention_mask = torch.ones_like(input_ids)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=2048,
        use_cache=True,
        # temperature and min_p only take effect when sampling is enabled;
        # without this they can be silently ignored (greedy decoding).
        do_sample=True,
        temperature=0.5,
        min_p=0.1,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    # (Splitting the decoded text on the word "assistant" truncated any reply
    # that itself contained that word.)
    prompt_length = input_ids.shape[-1]
    new_tokens = outputs[0][prompt_length:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return {"response": response}
@app.get("/")
def root():
    """Liveness check: report that the Friday service is up."""
    status = {"Friday": "is running!"}
    return status
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this file is run directly.
    import uvicorn

    # Defaults to the original hard-coded port 7860; a PORT environment
    # variable can override it without editing the file.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)