import os

import gradio as gr
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer
import spaces

# Define the model configuration
model_config = {
    "model_name": "admincybers2/sentinal",
    "max_seq_length": 1024,
    "dtype": torch.float16,
    "load_in_4bit": True,
}

# Hugging Face token
hf_token = os.getenv("HF_TOKEN")

# Load the model once when the application starts
loaded_model = None
loaded_tokenizer = None

def load_model():
    global loaded_model, loaded_tokenizer
    if loaded_model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_config["model_name"],
            torch_dtype=model_config["dtype"],
            device_map="auto",
            token=hf_token,  # `use_auth_token` is deprecated in recent transformers
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_config["model_name"],
            token=hf_token,
        )
        loaded_model = model
        loaded_tokenizer = tokenizer
    return loaded_model, loaded_tokenizer

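# Note: model_config["load_in_4bit"] is declared above but load_model() does not
# apply it. The sketch below shows one way it could be wired up, assuming the
# optional bitsandbytes package is installed in the Space; nothing calls it by default.
def load_model_4bit():
    """Alternative loader that honours model_config["load_in_4bit"] (sketch only)."""
    from transformers import BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_4bit=model_config["load_in_4bit"])
    model = AutoModelForCausalLM.from_pretrained(
        model_config["model_name"],
        quantization_config=quant_config,
        device_map="auto",
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_config["model_name"], token=hf_token)
    return model, tokenizer
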
# Vulnerability prompt template (one {} placeholder for the code snippet)
vulnerability_prompt = """Identify the specific line of code that is vulnerable and describe the type of software vulnerability.
### Vulnerable Line:
{}
### Vulnerability Description:
"""

@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces (why `spaces` is imported above)
def predict(prompt):
    model, tokenizer = load_model()
    formatted_prompt = vulnerability_prompt.format(prompt)  # fill the single {} placeholder
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)  # streams tokens to stdout (the Space logs)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        use_cache=True,
        do_sample=True,           # sampling must be enabled for the parameters below to take effect
        temperature=0.4,
        top_k=50,                 # consider only the 50 most likely next tokens
        top_p=0.9,                # nucleus sampling over the smallest set covering 90% of the probability mass
        min_p=0.01,               # drop tokens whose probability is below 1% of the top token's
        typical_p=0.95,           # locally typical sampling
        repetition_penalty=1.2,   # penalize repetitive sequences to improve diversity
        no_repeat_ngram_size=3,   # never repeat the same 3-gram
        renormalize_logits=True,  # renormalize the distribution after the filters above are applied
        max_new_tokens=640,
    )
    # The decoded string contains the prompt followed by the model's completion.
    return tokenizer.decode(output[0], skip_special_tokens=True)

theme = gr.themes.Default(
    primary_hue=gr.themes.colors.rose,
    secondary_hue=gr.themes.colors.blue,
    font=gr.themes.GoogleFont("Source Sans Pro"),
)

# Pre-load the model so the first request does not pay the loading cost
load_model()

with gr.Blocks(theme=theme) as demo:
    prompt = gr.Textbox(lines=5, placeholder="Enter your code snippet or topic here...", label="Prompt")
    generated_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")
    generate_button.click(predict, inputs=[prompt], outputs=generated_text)

    gr.Examples(
        examples=[
            ["$buff = 'A' x 10000;\nopen(myfile, '>>PASS.PK2');\nprint myfile $buff;\nclose(myfile);"]
        ],
        inputs=[prompt],
    )

demo.queue(default_concurrency_limit=10).launch(
    server_name="0.0.0.0",
    allowed_paths=["/"],
)