Spaces:
Sleeping
Sleeping
| import gradio | |
| import wandb | |
| import torch | |
| from transformers import GPT2Tokenizer,GPT2LMHeadModel | |
| from peft import PeftModel | |
| import os | |
| import re | |
def clean_text(text):
    """Normalize raw text into lowercase words separated by single spaces.

    The text is lowercased, every character matched by the regex class
    ``\\W`` (non-word characters) is replaced with a space, runs of
    whitespace are collapsed to one space, and surrounding whitespace is
    trimmed.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string; empty if ``text`` holds no word characters.
    """
    lowered = text.lower()
    # Punctuation and other non-word characters become separators.
    without_symbols = re.sub(r'\W', ' ', lowered)
    # Squash the separator runs produced by the previous step.
    collapsed = re.sub(r'\s+', ' ', without_symbols)
    return collapsed.strip()
# --- Model loading: runs once, at import time ---------------------------

# The W&B API key must be supplied through the environment (for example as
# a deployment secret). Credentials must never be committed to source
# control; any key that was previously hard-coded here should be treated as
# compromised and rotated.
if not os.environ.get("WANDB_API_KEY"):
    raise RuntimeError(
        "WANDB_API_KEY is not set; provide it through the environment "
        "(for example as a deployment secret)."
    )
wandb.login()

# Fetch the fine-tuned adapter weights published as a W&B artifact.
run = wandb.init(project="Email_subject_gen", job_type="model_loading")
artifact = run.use_artifact('Email_subject_gen/final_model:v0')
artifact_dir = artifact.download()
# The run exists only to download the artifact; close it so it does not
# stay open for the lifetime of the app.
run.finish()

# The tokenizer is built from the base checkpoint (not from the artifact
# directory) and extended with the padding token used below.
MODEL_KEY = 'olm/olm-gpt2-dec-2022'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_KEY)
tokenizer.add_special_tokens({'pad_token': '{PAD}'})

model = GPT2LMHeadModel.from_pretrained(MODEL_KEY)
# The embedding matrix must grow to cover the newly added pad token.
model.resize_token_embeddings(len(tokenizer))
# NOTE(review): these attributes are set after the model has been built, and
# GPT-2 configs normally name their dropout fields resid_pdrop / embd_pdrop /
# attn_pdrop, so these assignments probably have no effect - confirm.
model.config.dropout = 0.1  # Set dropout rate
model.config.attention_dropout = 0.1

# Wrap the base model with the adapter weights downloaded above.
model = PeftModel.from_pretrained(model, artifact_dir)
# This app only runs inference; make evaluation mode explicit.
model.eval()
def generateSubject(email):
    """Generate a subject line for an email body.

    The email is normalized with ``clean_text``, wrapped as
    ``<email>...<subject>`` and passed to the module-level ``model``; the
    text that follows the ``<subject>`` marker in the decoded output is
    returned.

    Args:
        email: Raw email body text.

    Returns:
        The generated subject text, or an empty string when the decoded
        output does not contain the ``<subject>`` marker (for example if
        the marker was lost to truncation of a very long prompt).
    """
    marker = '<subject>'
    prompt = "<email>" + clean_text(email) + marker

    # Left padding keeps the prompt adjacent to the generated tokens.
    tokenizer.padding_side = 'left'
    try:
        prompts_batch_ids = tokenizer(
            [prompt], padding=True, truncation=True, return_tensors='pt'
        ).to(model.device)
        output_ids = model.generate(
            **prompts_batch_ids,
            max_new_tokens=10,
            pad_token_id=tokenizer.pad_token_id,
        )
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    finally:
        # Always put the shared tokenizer back, even if generation failed.
        tokenizer.padding_side = 'right'

    outputs_batch = []
    for seq in decoded:
        pieces = seq.split(marker)
        # pieces[1] is the text after the first marker (and before any
        # repeated one); fall back to '' instead of raising IndexError.
        outputs_batch.append(pieces[1] if len(pieces) > 1 else '')

    print(outputs_batch)
    return outputs_batch[0]
def predict(name):
    """Return a greeting of the form ``Hello <name>!!``.

    Args:
        name: The string to greet.

    Returns:
        The greeting string.
    """
    greeting = "Hello " + name
    return greeting + "!!"
# Build the web UI: one text input (the email body) mapped through
# generateSubject to one text output (the generated subject).
iface = gradio.Interface(
    fn=generateSubject,
    inputs="text",
    outputs="text",
)
# Start serving the interface.
iface.launch()