# test-st / app.py — Hugging Face Space (commit c7b328e, "Update app.py", author: bkv-ata)
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
import pandas as pd
# Pick GPU when available; the model/tensors are moved to this device in get_topic().
device = "cuda" if torch.cuda.is_available() else "cpu"
# Download (or load from cache) the fine-tuned topic-generation model and its tokenizer.
# NOTE(review): these run at import time — first launch blocks on the network download.
model = AutoModelForCausalLM.from_pretrained("Celestinian/TopicGPT")
tokenizer = AutoTokenizer.from_pretrained("Celestinian/TopicGPT")
def get_topic(prompt, temperature=0.01, max_size=20):
    """Generate a short topic label for *prompt* by sampling from the model.

    The prompt is wrapped in the model's training format
    ("#CONTEXT# <prompt> #TOPIC#") and tokens are sampled one at a time
    with temperature-scaled multinomial sampling until a '#' token is
    produced or *max_size* tokens have been generated.

    Args:
        prompt: Input text (e.g. a review) to label.
        temperature: Softmax temperature; must be > 0. The low default
            makes sampling close to greedy decoding.
        max_size: Maximum number of tokens to generate.

    Returns:
        The decoded topic string with newline characters removed.
    """
    input_ids = tokenizer.encode(
        "#CONTEXT# " + prompt + " #TOPIC#", return_tensors='pt'
    ).to(device)

    model.eval()
    model.to(device)

    # '#' terminates the topic in the model's prompt format.
    eos_token_id = tokenizer.encode('#')[0]

    output_tokens = []
    # Enter no_grad once for the whole sampling loop instead of per token.
    with torch.no_grad():
        for _ in range(max_size):
            logits = model(input_ids).logits[:, -1, :] / temperature
            next_token = torch.multinomial(
                torch.softmax(logits, dim=-1), num_samples=1
            )
            if next_token.item() == eos_token_id:
                break
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            output_tokens.append(next_token.item())

    return tokenizer.decode(output_tokens).replace('\n', '')
# Streamlit UI: a text box pre-filled with a sample review; on every rerun the
# current text is passed through get_topic() and the predicted topic is shown.
text = st.text_area("Review", "Trustpilot is the one and only you can trust and share your thoughts with others.Love it !")
outputs = get_topic(text)
st.write(outputs)