File size: 1,359 Bytes
9d76abe
 
5a35c60
 
9d76abe
 
 
54368bc
9d76abe
 
 
 
 
 
 
 
 
c7b328e
9d76abe
 
 
 
 
 
 
 
 
 
c7b328e
9d76abe
 
c7b328e
9d76abe
 
 
f325d03
9d76abe
05aab8c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
import pandas as pd
# Prefer GPU when available; get_topic() moves both inputs and model here.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Download (or load from cache) the TopicGPT causal LM and its tokenizer
# from the HuggingFace Hub at import time — this blocks app startup.
model = AutoModelForCausalLM.from_pretrained("Celestinian/TopicGPT")
tokenizer = AutoTokenizer.from_pretrained("Celestinian/TopicGPT")


def get_topic(prompt, temperature=0.01, max_size=20):
    """Generate a short topic label for *prompt* by token-by-token sampling.

    The prompt is wrapped as ``#CONTEXT# <prompt> #TOPIC#`` and fed to the
    module-level ``model``; tokens are drawn one at a time from the
    temperature-scaled softmax until a ``#`` token appears or *max_size*
    tokens have been produced.

    Args:
        prompt: Free-form input text to summarise into a topic.
        temperature: Softmax temperature; the default 0.01 makes the
            multinomial draw near-greedy.
        max_size: Maximum number of tokens to generate.

    Returns:
        The decoded topic string with newline characters removed.
    """
    input_ids = tokenizer.encode("#CONTEXT# " + prompt + " #TOPIC#", return_tensors='pt')
    input_ids = input_ids.to(device)
    model.eval()
    model.to(device)

    output_tokens = []
    # NOTE(review): assumes encode('#') yields the stop token at index 0;
    # tokenizers that prepend special tokens would break this — verify.
    eos_token_id = tokenizer.encode('#')[0]
    for _ in range(max_size):
        with torch.no_grad():
            outputs = model(input_ids)
        # Sample the next token from the scaled distribution over the last
        # position's logits (no KV cache: full forward pass each step).
        logits = outputs.logits[:, -1, :] / temperature
        next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        if next_token.item() == eos_token_id:
            break
        input_ids = torch.cat((input_ids, next_token), dim=-1)
        output_tokens.append(next_token.item())

    output = tokenizer.decode(output_tokens)
    return output.replace('\n', '')


# Streamlit UI: collect a review, extract its topic, and render the result.
review = st.text_area(
    "Review",
    "Trustpilot is the one and only you can trust and share your thoughts with others.Love it !",
)
topic = get_topic(review)
st.write(topic)