|
|
# Standard library
import logging

# Third-party.
# NOTE(review): unsloth is imported before transformers on purpose — unsloth
# patches transformers at import time; do not reorder these two.
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
|
|
|
|
|
|
|
|
# FastAPI application exposing the finetuned model behind a /generate endpoint.
app = FastAPI(title="Llama-3.1 Finetuned API", version="1.0.0")
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Model bootstrap: load the LoRA-finetuned model once at import time.
# On failure the API still starts; /generate then reports "Model not loaded".
# ---------------------------------------------------------------------------
logger = logging.getLogger(__name__)

try:
    # Local LoRA adapter directory; assumed resolvable relative to the
    # working directory — TODO confirm deployment layout.
    lora_adapter_path = "cyber_llama_32"

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=lora_adapter_path,
        max_seq_length=2048,   # context window used for this model
        load_in_4bit=True,     # 4-bit quantization to reduce GPU memory
    )

    # Enable unsloth's optimized inference mode (disables training paths).
    FastLanguageModel.for_inference(model)
except Exception:
    # Broad catch is intentional at this top-level boundary: the server must
    # still come up so clients get a clear error instead of a dead port.
    # logger.exception records the full traceback (the old print() dropped it).
    logger.exception("Error loading model")

    model = None
    tokenizer = None
|
|
|
|
|
|
|
|
class PromptRequest(BaseModel):
    """Request body accepted by the /generate endpoint."""

    # Instruction text that is inserted into the model's prompt template.
    prompt: str = Field(..., description="The user's prompt or instruction for the model.")

    # Upper bound on the number of tokens generated (must be >= 1).
    max_new_tokens: int = Field(default=512, ge=1, description="Maximum number of tokens to generate.")

    # Generation halts as soon as the model emits one of these strings.
    stop_sequences: list[str] = Field(default=[".", "!", "?"], description="A list of strings that will stop the generation.")
|
|
|
|
|
|
|
|
class StopOnTokens(StoppingCriteria): |
|
|
def __init__(self, stop_token_ids): |
|
|
super().__init__() |
|
|
self.stop_token_ids = stop_token_ids |
|
|
|
|
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
|
|
return any(input_ids[0][-1] == token_id for token_id in self.stop_token_ids) |
|
|
|
|
|
|
|
|
@app.post("/generate", summary="Generates text based on a given prompt")
async def generate(request: PromptRequest):
    """Generate a completion for ``request.prompt`` with the finetuned model.

    Returns a JSON object containing ``generated_text`` on success, or an
    ``error`` key if the model failed to load at startup.
    """
    # Guard: model loading at import time may have failed (see bootstrap).
    if not model or not tokenizer:
        return {"error": "Model not loaded. Please check the server logs."}

    # Alpaca-style template used at finetune time; the Response slot is left
    # empty so the model continues from there. The interior blank lines are
    # part of the trained prompt format — do not reflow this literal.
    alpaca_prompt = """You are a trustworthy cybersecurity and privacy assistant that provides clear, safe, and practical guidance on protecting data, avoiding threats, and staying secure online.




### Instruction:

Analyse the user input and answer the question carefully. Please try to obey the cybersecurity and privacy laws.




### Input:

{}




### Response:

{}"""

    # Tokenize and move tensors to wherever the model actually lives, rather
    # than hard-coding "cuda" (the old code broke for CPU-loaded models).
    inputs = tokenizer(
        [alpaca_prompt.format(request.prompt, "")],
        return_tensors="pt",
    ).to(model.device)

    # NOTE(review): convert_tokens_to_ids maps each stop string to a single
    # vocabulary id; multi-character sequences that are not single tokens
    # resolve to the unk id — confirm stop_sequences are single-token strings.
    stop_token_ids = tokenizer.convert_tokens_to_ids(request.stop_sequences)
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])

    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_new_tokens,
        use_cache=True,
        do_sample=True,  # sampling: responses are intentionally non-deterministic
        stopping_criteria=stopping_criteria,
    )

    # Slice off the prompt tokens so only the newly generated text is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.batch_decode(
        outputs[:, prompt_length:], skip_special_tokens=True
    )[0]

    return {"generated_text": generated_text}
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Development entry point; in production prefer invoking an ASGI server
    # directly (e.g. `uvicorn module:app --host 0.0.0.0 --port 8000`).
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)