# vaccineAI/app.py
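"""FastAPI inference service for a QLoRA-finetuned Mistral-7B.

Loads the mistralai/Mistral-7B-v0.1 base model in 4-bit (NF4) and applies
the fansa34/finetunedModel LoRA adapter, exposing a single /ask endpoint.
"""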
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
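
# To serve the app (assumed invocation; 7860 is the usual Hugging Face Spaces
# port, adjust for other hosts):
#   uvicorn app:app --host 0.0.0.0 --port 7860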
# Configs
BASE_MODEL = "mistralai/Mistral-7B-v0.1"
ADAPTER_MODEL = "fansa34/finetunedModel"
# Quantization for 4-bit loading (QLoRA)
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
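
# With NF4 plus double quantization, the 7B base weights fit in roughly 4 GB
# of GPU memory (a ballpark figure; actual usage depends on hardware and kernels).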
# Load tokenizer and base model
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
# Mistral's tokenizer has no pad token by default; fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    quantization_config=quant_config,
)
# Load LoRA adapter
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
model.eval()
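
# Optional: peft's merge_and_unload() can fold the LoRA weights into the base
# model for lower per-request latency, but merging into 4-bit quantized layers
# is version-dependent in peft, so it is left commented out here.
# model = model.merge_and_unload()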
# Request schema
class QueryRequest(BaseModel):
    question: str
    max_new_tokens: int = 200
    temperature: float = 0.6
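
# Example request body for /ask (illustrative values):
#   {"question": "How do mRNA vaccines work?", "max_new_tokens": 200, "temperature": 0.6}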
@app.post("/ask")
async def ask(req: QueryRequest):
    prompt = f"Question: {req.question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            temperature=req.temperature,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns the prompt plus the completion; keep only the text
    # after the "Answer:" marker.
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    answer = text.split("Answer:")[-1].strip()
    return {"question": req.question, "answer": answer}
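
# Example call against a hypothetical local run:
#   curl -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Are booster shots necessary?"}'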