# Quick inference check for the merged TinyLlama log-summarization model: load
# the merged weights, run one sample log line through generate(), and print the
# extracted summary.
import sys
import traceback

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MERGED_MODEL_PATH = "./merged_tinyllama_logger"

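# Synthetic Java stack-trace log line used as the test input.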
SAMPLE_LOG = """2023-03-06 15:38:41 ERROR [Worker-11] org.hibernate.exception.ConstraintViolationException at at com.example.CacheManager.land(CacheManager.java:359) at at com.example.ShippingService.discover(CacheManager.java:436) at at com.example.HttpClient.work(DatabaseConnector.java:494) at at com.example.ShippingService.window(OrderModule.java:378) at at com.example.CacheManager.almost(DatabaseConnector.java:326) at at com.example.DatabaseConnector.couple(AuthModule.java:13) at at com.example.PaymentModule.wrong(HttpClient.java:244)."""

# Initialize so a failed load is detected below instead of raising NameError.
model = None
tokenizer = None

try:
    print("Loading AutoModelForCausalLM...")
    model = AutoModelForCausalLM.from_pretrained(
        MERGED_MODEL_PATH,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print("AutoModelForCausalLM loaded successfully.")
    print("Loading AutoTokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MERGED_MODEL_PATH)
    print("AutoTokenizer loaded successfully.")
except Exception:
    print(f"ERROR: could not load model or tokenizer from {MERGED_MODEL_PATH!r}")
    traceback.print_exc()

if model is None or tokenizer is None:
    sys.exit(1)

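# Fall back to EOS when the tokenizer defines no pad token; left padding is the
# convention for decoder-only generation.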
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

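# Prompt format: the raw log line followed by a newline. The model is expected to
# continue with the summary (an assumption mirrored by the extraction logic below).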
prompt = SAMPLE_LOG + "\n"

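# Tokenize the prompt and move the tensors to the model's device.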
inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(model.device)

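# Sample conservatively (low temperature, nucleus + top-k filtering) and cap the
# continuation at 60 new tokens to keep the summary short.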
with torch.no_grad():
    output_tokens = model.generate(
        **inputs,
        max_new_tokens=60,
        temperature=0.3,
        do_sample=True,
        top_p=0.9,
        top_k=30,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1,
    )

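# Decode the full sequence (prompt echo + continuation) before extracting the summary.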
generated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
print(f"Generated Text: {generated_text}")
print("END OF GENERATED TEXT")
# Heuristic extraction: slice from the end of the echoed prompt to the last
# "AM"/"PM" marker (the training targets apparently end with a 12-hour
# timestamp); otherwise fall back to everything after the prompt.
summary_start_index = len(SAMPLE_LOG) + 1  # skip the echoed log plus the newline
summary_end_index = -1
summary = ""

if "PM" in generated_text:
    summary_end_index = generated_text.rfind("PM") + len("PM")
elif "AM" in generated_text:
    summary_end_index = generated_text.rfind("AM") + len("AM")

if summary_end_index > summary_start_index:
    summary = generated_text[summary_start_index:summary_end_index].strip()
else:
    # No timestamp marker found: take whatever follows the echoed prompt.
    prompt_end_index = generated_text.find(prompt)
    if prompt_end_index != -1:
        summary = generated_text[prompt_end_index + len(prompt):].strip()
    else:
        summary = generated_text.strip()

print(summary)