File size: 3,344 Bytes
b49a99c
 
 
 
 
 
 
 
 
 
 
49354a7
b49a99c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging

import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

# Module-level logger (configured by the hosting ASGI server / uvicorn).
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Llama-3.1 Finetuned API", version="1.0.0")

# --- Model Loading ---
# Load the LoRA-finetuned model once at import time so all requests share it.
try:
    lora_adapter_path = "cyber_llama_32"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=lora_adapter_path,
        max_seq_length=2048,   # maximum context length for tokenization/generation
        load_in_4bit=True,     # 4-bit quantization to reduce GPU memory usage
    )
    # Enable unsloth's optimized inference mode (disables training-only paths).
    FastLanguageModel.for_inference(model)
except Exception:
    # Broad catch is deliberate at this startup boundary: the API should still
    # come up and report the failure per-request instead of crashing on import.
    # logger.exception records the full traceback, which print(f"...{e}") lost.
    logger.exception("Error loading model")
    # Set to None to handle errors gracefully in the API endpoint
    model = None
    tokenizer = None

# Pydantic model for request body
class PromptRequest(BaseModel):
    # Free-form instruction text; it is substituted into the Alpaca-style
    # prompt template in /generate before being fed to the model.
    prompt: str = Field(..., description="The user's prompt or instruction for the model.")
    # Hard cap on the number of newly generated tokens (must be >= 1).
    max_new_tokens: int = Field(512, ge=1, description="Maximum number of tokens to generate.")
    # Stop strings converted to token ids in /generate. NOTE(review): only
    # strings that map to a single vocabulary token can actually stop
    # generation — multi-token sequences are not matched by StopOnTokens.
    stop_sequences: list[str] = Field([".", "!", "?"], description="A list of strings that will stop the generation.")

# A custom stopping criteria class for the stop sequences
class StopOnTokens(StoppingCriteria):
    """Stops generation as soon as the last generated token is a stop token."""

    def __init__(self, stop_token_ids):
        """stop_token_ids: iterable of token ids that should end generation.

        ``None`` entries are dropped: ``tokenizer.convert_tokens_to_ids``
        returns ``None`` for strings that are not in the vocabulary, and a
        ``None`` in the set could never match a real token id anyway.
        """
        super().__init__()
        # A set gives O(1) membership tests and removes duplicates.
        self.stop_token_ids = {t for t in stop_token_ids if t is not None}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only batch element 0 is checked — the endpoint generates a single
        # sequence. .item() converts the 0-dim tensor to a plain int so the
        # lookup is one Python set probe instead of N tensor equality ops.
        return input_ids[0][-1].item() in self.stop_token_ids

# API endpoint for text generation
@app.post("/generate", summary="Generates text based on a given prompt")
async def generate(request: PromptRequest):
    """Generate a completion for ``request.prompt`` with the finetuned model.

    Returns ``{"generated_text": ...}`` on success, or ``{"error": ...}``
    when the model failed to load at startup.
    """
    if model is None or tokenizer is None:
        return {"error": "Model not loaded. Please check the server logs."}

    # Alpaca-style prompt template the model was finetuned with.
    alpaca_prompt = """You are a trustworthy cybersecurity and privacy assistant that provides clear, safe, and practical guidance on protecting data, avoiding threats, and staying secure online.

### Instruction:
Analyse the user input and answer the question carefully. Please try to obey the cybersecurity and privacy laws.

### Input:
{}

### Response:
{}"""

    inputs = tokenizer(
        [
            alpaca_prompt.format(
                request.prompt,  # input from the user
                "",  # empty response to be filled by the model
            )
        ],
        return_tensors="pt",
        # Use the model's actual device instead of hard-coding "cuda", so the
        # endpoint also works on CPU-only deployments.
    ).to(model.device)

    # Map each stop string to a token id. convert_tokens_to_ids returns None
    # for strings that are not single vocabulary tokens, so drop those.
    # NOTE(review): multi-token stop sequences are not supported — StopOnTokens
    # only inspects the single most recent token.
    stop_token_ids = [
        tid
        for tid in tokenizer.convert_tokens_to_ids(request.stop_sequences)
        if tid is not None
    ]

    # Create the stopping criteria list
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])

    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_new_tokens,
        use_cache=True,
        do_sample=True,  # Recommended for better creative responses
        stopping_criteria=stopping_criteria,
    )

    # Slice off the prompt tokens so only newly generated text is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = tokenizer.batch_decode(
        outputs[:, prompt_len:], skip_special_tokens=True
    )[0]

    return {"generated_text": generated_text}

# This section is for local testing and will not be run on Hugging Face Spaces
if __name__ == "__main__":
    # Imported lazily so deployments that never run this path (e.g. Spaces,
    # which start the app via an external ASGI server) don't need uvicorn here.
    import uvicorn
    # Make sure to include the ngrok setup for local testing on Colab
    # Binds to all interfaces so the server is reachable from outside the VM.
    uvicorn.run(app, host="0.0.0.0", port=8000)