# Cyber_Bot / app.py — FastAPI service exposing a fine-tuned Llama model
# for cybersecurity and privacy Q&A (Hugging Face Space, author: KoKoDanio).
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
# Initialize FastAPI app
app = FastAPI(title="Llama-3.1 Finetuned API", version="1.0.0")

# --- Model Loading ---
# Load the fine-tuned LoRA adapter once at import time so every request
# reuses the same model/tokenizer pair. Loading in 4-bit keeps memory low.
try:
    # Path to the LoRA adapter directory (relative to the app's working dir).
    lora_adapter_path = "cyber_llama_32"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=lora_adapter_path,
        max_seq_length=2048,     # maximum context length accepted at load time
        load_in_4bit=True,       # 4-bit quantization to fit on small GPUs
    )
    # Switch Unsloth's model into its optimized inference mode.
    FastLanguageModel.for_inference(model)
except Exception as e:
    # Broad catch is deliberate: the API should still start and report the
    # failure per-request instead of crashing the whole server at import.
    print(f"Error loading model: {e}")
    # Set to None to handle errors gracefully in the API endpoint
    model = None
    tokenizer = None
# Pydantic model for request body
# NOTE: no class docstring on purpose — pydantic would surface it as the
# schema description in the generated OpenAPI docs.
class PromptRequest(BaseModel):
    # Required free-text instruction from the caller.
    prompt: str = Field(..., description="The user's prompt or instruction for the model.")
    # Generation budget; must be at least 1 (ge=1).
    max_new_tokens: int = Field(512, ge=1, description="Maximum number of tokens to generate.")
    # Strings whose token ids end generation early; defaults to sentence enders.
    # (Pydantic deep-copies field defaults per instance, so the shared list is safe.)
    stop_sequences: list[str] = Field([".", "!", "?"], description="A list of strings that will stop the generation.")
# A custom stopping criteria class for the stop sequences
class StopOnTokens(StoppingCriteria):
    """Stop generation once the most recently emitted token is a stop id."""

    def __init__(self, stop_token_ids):
        super().__init__()
        # Token ids that should terminate generation when produced.
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the newest token (batch element 0) can trigger a stop.
        latest = input_ids[0][-1]
        for stop_id in self.stop_token_ids:
            if latest == stop_id:
                return True
        return False
# API endpoint for text generation
@app.post("/generate", summary="Generates text based on a given prompt")
async def generate(request: PromptRequest):
    """Generate a completion for the caller's prompt.

    Returns ``{"generated_text": ...}`` on success, or ``{"error": ...}``
    (with HTTP 200) when the model failed to load at startup.
    """
    if not model or not tokenizer:
        return {"error": "Model not loaded. Please check the server logs."}

    # The prompt template for the model (Alpaca-style instruction format).
    alpaca_prompt = """You are a trustworthy cybersecurity and privacy assistant that provides clear, safe, and practical guidance on protecting data, avoiding threats, and staying secure online.
### Instruction:
Analyse the user input and answer the question carefully. Please try to obey the cybersecurity and privacy laws.
### Input:
{}
### Response:
{}"""

    inputs = tokenizer(
        [
            alpaca_prompt.format(
                request.prompt,  # input from the user
                "",  # empty response to be filled by the model
            )
        ],
        return_tensors="pt",
    ).to(model.device)  # FIX: follow the model's actual device instead of hard-coding "cuda"

    # FIX: convert_tokens_to_ids maps strings that are not single vocabulary
    # tokens to the unk id, or to None when the tokenizer has no unk token.
    # Drop unresolved entries so StopOnTokens never compares a tensor to None.
    raw_ids = tokenizer.convert_tokens_to_ids(request.stop_sequences)
    stop_token_ids = [tid for tid in raw_ids if isinstance(tid, int)]

    # Create the stopping criteria list
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])

    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_new_tokens,
        use_cache=True,
        do_sample=True,  # Recommended for better creative responses
        stopping_criteria=stopping_criteria,
    )

    # Decode only the newly generated tokens (slice off the prompt portion).
    generated_text = tokenizer.batch_decode(
        outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )[0]
    return {"generated_text": generated_text}
# This section is for local testing and will not be run on Hugging Face Spaces
if __name__ == "__main__":
    import uvicorn
    # Make sure to include the ngrok setup for local testing on Colab
    # Binds on all interfaces so the server is reachable from outside the container.
    uvicorn.run(app, host="0.0.0.0", port=8000)