File size: 1,876 Bytes
6d91ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import ast
import json

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# System instructions placed in front of every review before generation.
# There is one entry per instruction so individual rules are easy to edit and
# diff. The entries are joined with "\n"; the empty entry produces the blank
# line between the role sentence and the "Instructions:" header.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an advanced AI model specialized in extracting aspects and determining their sentiment polarity from customer reviews.",
        "",
        "Instructions:",
        "1. Extract only the aspects (nouns) mentioned in the review.",
        '2. Assign a sentiment to each aspect: "positive", "negative", or "neutral".',
        "3. Return aspects in the same language as they appear.",
        "4. An aspect must be a noun that refers to a specific item or service the user described.",
        "5. Ignore adjectives, general ideas, and vague topics.",
        "6. Do NOT translate, explain, or add extra text.",
        "7. The output must be just a valid JSON list with 'aspect' and 'sentiment'. Start with `[` and stop at `]`.",
        "8. Do NOT output the instructions, review, or any text — only one output JSON list.",
        "9. Just one output and one review.",
    )
)



def _parse_model_output(text):
    """Convert decoded model text into a Python object on a best-effort basis.

    Strategies are tried in order:

    1. ``json.loads`` on the whole text (the original behaviour).
    2. ``json.loads`` on the span from the first ``[`` to the last ``]``.
       The prompt asks for a bracket-delimited list, but the model may wrap
       it in extra words.
    3. ``ast.literal_eval`` on that same span, accepted only if it yields a
       list. The prompt itself spells the keys with single quotes
       ('aspect', 'sentiment'), so Python-style quoting is a plausible
       output. ``literal_eval`` only evaluates literals, never arbitrary
       code.

    Args:
        text: The decoded, stripped generation.

    Returns:
        The parsed object, or ``text`` unchanged if no strategy succeeds.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end < start:
        return text
    candidate = text[start:end + 1]

    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass

    try:
        parsed = ast.literal_eval(candidate)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return text
    # e.g. "[1], [2]" evaluates to a tuple; only a list matches the contract.
    return parsed if isinstance(parsed, list) else text


def infer_t5_prompt(review_text, tokenizer, peft_model, *, max_new_tokens=256, num_beams=4):
    """Extract aspect/sentiment pairs from one review with a seq2seq PEFT model.

    The review is appended to ``SYSTEM_PROMPT``, tokenized, and decoded with
    deterministic beam search (no sampling).

    Args:
        review_text: The customer review to analyse.
        tokenizer: Hugging Face tokenizer matching ``peft_model``.
        peft_model: Model exposing ``generate`` and ``device`` (e.g. a
            ``PeftModel`` wrapping an ``AutoModelForSeq2SeqLM``).
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width used for beam search.

    Returns:
        The parsed value (normally a list of ``{"aspect": ..., "sentiment": ...}``
        dicts) when the output can be parsed, otherwise the raw decoded
        string. Callers must handle both types.
    """
    prompt = SYSTEM_PROMPT + f"\n\nReview: {review_text}"

    # truncation=True without max_length caps the input at the tokenizer's
    # model_max_length (per the transformers docs). With the default
    # right-side truncation, a very long review loses its tail, not the
    # instructions.
    inputs = tokenizer(
        prompt, return_tensors="pt", padding=True, truncation=True
    ).to(peft_model.device)

    with torch.no_grad():
        # do_sample=False means deterministic beam search, so sampling knobs
        # such as temperature are unused. Passing one only triggers a
        # transformers warning, so none is passed.
        outputs = peft_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    decoded = tokenizer.decode(
        outputs[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    ).strip()

    # Defensive: skip_special_tokens should already drop these markers, but
    # strip them in case the tokenizer does not register them as special.
    decoded = decoded.replace("<extra_id_0>", "").replace("</s>", "").strip()

    return _parse_model_output(decoded)