import os
import logging
from confluent_kafka import KafkaException, Producer
import json
import torch
from transformers import TextStreamer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from confluent_kafka.serialization import (
    MessageField,
    SerializationContext,
)
from unsloth import FastLanguageModel
from uuid import uuid4
import concurrent.futures

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
hf_token = os.getenv("HF_TOKEN")
class MessageSend:
    def __init__(self, username, title, level, detail=None):
        self.username = username
        self.title = title
        self.level = level
        self.detail = detail
def cover_message(msg):
    """Return a dictionary representation of a MessageSend instance for serialization."""
    return dict(
        username=msg.username,
        title=msg.title,
        level=msg.level,
        detail=msg.detail,
    )
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TooManyRequestsError(Exception):
    def __init__(self, retry_after):
        super().__init__(f"Too many requests, retry after {retry_after}s")
        self.retry_after = retry_after
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="admincybers2/sentinal",
    max_seq_length=4096,
    dtype=None,
    load_in_4bit=True,
    token=hf_token,
)
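# Note: load_in_4bit=True uses the bitsandbytes 4-bit backend and requires a CUDA GPU;
# dtype=None lets unsloth auto-select float16/bfloat16 for the detected hardware.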
# Enable native 2x faster inference
FastLanguageModel.for_inference(model)
vulnerable_prompt = (
    "Identify the line of code that is vulnerable and describe the type of software "
    "vulnerability, no yapping if no vulnerable code found pls return 'no vulnerable'"
    "\n### Code Snippet:\n{}\n### Vulnerability Description:\n{}"
)
def extract_data(full_message):
    try:
        message = json.loads(full_message)
        return message
    except json.JSONDecodeError as e:
        logger.error(f"Failed to extract data: {e}")
        return None
def perform_ai_task(question):
    prompt = vulnerable_prompt.format(question, "")
    # Move the tokenized prompt onto the same device as the model to avoid a
    # CPU/GPU mismatch during generation.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    text_streamer = TextStreamer(tokenizer)
    try:
        model_output = model.generate(
            **inputs,
            streamer=text_streamer,
            use_cache=True,
            max_new_tokens=640,
            temperature=0.5,
            top_k=50,
            top_p=0.9,
            min_p=0.01,
            typical_p=0.95,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
        )
        generated_text = tokenizer.decode(model_output[0], skip_special_tokens=True)
    except RuntimeError as e:
        error_message = str(e)
        if "probability tensor contains either `inf`, `nan` or element < 0" in error_message:
            logger.error("Encountered probability tensor error, skipping this task.")
        else:
            logger.error(f"Runtime error during model generation: {error_message}.")
        # generated_text is never assigned when generation fails, so skip the task
        # instead of falling through to deduplicate_text().
        return None
    deduplicated_text = deduplicate_text(generated_text)
    return {
        "detail": deduplicated_text
    }
def deduplicate_text(text):
    """Remove repeated sentences while preserving their original order."""
    sentences = text.split('. ')
    seen_sentences = set()
    deduplicated_sentences = []
    for sentence in sentences:
        if sentence not in seen_sentences:
            seen_sentences.add(sentence)
            deduplicated_sentences.append(sentence)
    result = '. '.join(deduplicated_sentences)
    # The last sentence usually keeps its own period; only append one if it is missing.
    return result if result.endswith('.') else result + '.'
def delivery_report(err, msg):
    if err is not None:
        logger.error(f"Message delivery failed: {err}")
    else:
        logger.info(f"Message delivered to {msg.topic()} [{msg.partition()}]")
def handle_message(msg, producer, ensure_producer_connected, avro_serializer):
    logger.info(f"Message value {msg}")
    if msg:
        ensure_producer_connected(producer)
        try:
            ai_results = perform_ai_task(msg['message_send'])
            if ai_results is None:
                logger.error("AI task skipped due to an error in model generation.")
                return
            detail = ai_results.get("detail", "No details available")
            topic = "get_scan_message"
            messagedict = cover_message(
                MessageSend(
                    username=msg['username'],
                    title=msg['path'],
                    level='',
                    detail=detail
                )
            )
            if messagedict:
                # Avro-encode the payload for the target topic, then produce it with a
                # correlation id header and a delivery callback.
                byte_value = avro_serializer(messagedict, SerializationContext(topic, MessageField.VALUE))
                producer.produce(
                    topic,
                    value=byte_value,
                    headers={"correlation_id": str(uuid4())},
                    callback=delivery_report
                )
                producer.flush()
            else:
                logger.error("Message serialization failed; skipping production.")
        except KafkaException as e:
            logger.error(f"Kafka error producing message: {e}")
        except Exception as e:
            logger.error(f"Unhandled error in handle_message: {e}")