| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import torch |
| import streamlit as st |
| import pandas as pd |
# Pick the fastest available device; fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained topic-labelling checkpoint (GPT-style causal LM + its tokenizer).
MODEL_NAME = "Celestinian/TopicGPT"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
|
def get_topic(prompt, temperature=0.01, max_size=20):
    """Sample a short topic label for *prompt* from the TopicGPT model.

    The model was trained on sequences of the form
    ``#CONTEXT# <text> #TOPIC# <topic>#``, so the prompt is wrapped in that
    template and tokens are sampled until the closing ``#`` (or *max_size*
    tokens) is reached.

    Args:
        prompt: Free-form input text to label.
        temperature: Softmax temperature; the default 0.01 makes sampling
            near-greedy.
        max_size: Maximum number of topic tokens to generate.

    Returns:
        The decoded topic string with newlines removed.
    """
    model.eval()
    model.to(device)

    input_ids = tokenizer.encode(
        "#CONTEXT# " + prompt + " #TOPIC#", return_tensors='pt'
    ).to(device)

    # '#' terminates the topic in the training format.
    # NOTE(review): assumes '#' encodes to exactly one token for this
    # tokenizer — confirm; otherwise [0] silently uses only its first piece.
    eos_token_id = tokenizer.encode('#')[0]

    output_tokens = []
    # One no_grad context around the whole loop instead of re-entering it
    # per step; also avoids leftover debug prints from the original.
    with torch.no_grad():
        for _ in range(max_size):
            logits = model(input_ids).logits[:, -1, :] / temperature
            next_token = torch.multinomial(
                torch.softmax(logits, dim=-1), num_samples=1
            )
            if next_token.item() == eos_token_id:
                break
            # Feed the sampled token back in for the next step.
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            output_tokens.append(next_token.item())

    output = tokenizer.decode(output_tokens)
    return output.replace('\n', '')
|
|
|
|
# Streamlit UI: collect a review and display the model's topic label.
text = st.text_area("Review", "Trustpilot is the one and only you can trust and share your thoughts with others.Love it !")

# Only run the model when there is actual input; a blank text area would
# otherwise waste a forward pass on an empty prompt.
if text.strip():
    outputs = get_topic(text)
    st.write(outputs)