Khushali-shah's picture
Update app.py
a16bb2d verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face Hub repository holding the fine-tuned Phi-2 classifier. The
# model and tokenizer must come from the same repo, so it is named once here.
MODEL_ID = "hiraltalsaniya/phi2-task-classification-demo"

# Weights are loaded explicitly as float32 rather than torch_dtype="auto".
# NOTE(review): presumably because the Space runs on CPU — confirm.
# trust_remote_code=True lets transformers execute Python code shipped in the
# model repository, so the repo must be trusted.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Prompt format the classifier was fine-tuned on; "{}" receives the user's
# task description.
PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction: Categorize the IT task description into one of the 6 categories:
Client Meeting\nDevelopment\nInternal Meeting\nInterview\nLearning\nReview
### User Input:
{}
### Response:
"""

# Upper bound on tokens generated after the prompt. max_new_tokens (unlike
# max_length, which also counts the prompt's tokens) always leaves the model
# room to answer, however long the task description is.
MAX_NEW_TOKENS = 128


def build_prompt(text):
    """Insert *text* into PROMPT_TEMPLATE and return the complete prompt."""
    return PROMPT_TEMPLATE.format(text)


def classify(text):
    """Classify an IT task description with the fine-tuned model.

    Args:
        text: Free-form task description entered in the Gradio textbox.

    Returns:
        Only the model's generated response (expected to name one of the six
        categories), with the echoed prompt and special tokens removed.
        Returns "" without running the model when *text* is empty, None, or
        whitespace-only.
    """
    if not text or not text.strip():
        return ""

    prompt = build_prompt(text)
    # A single, unpadded prompt: its attention mask would be all ones, so it
    # is not requested.
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        # Set explicitly so generate() does not have to fall back to EOS as
        # the pad token (and warn about it) on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    # For a causal LM, generate() returns prompt + continuation; keep only
    # the newly generated tokens so the user sees the answer, not the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    generated = outputs[0][prompt_length:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
description = "This AI model is trained to classify texts."
title = "Classify Your Texts"

# Single free-text input -> single text output, backed by classify().
iface = gr.Interface(
    fn=classify,
    inputs="text",
    outputs="text",
    title=title,
    description=description,
)

# Gradio logs the URL it serves on when it launches, so no extra print is
# needed. gr.Interface has no share() method; to get a public link, pass
# share=True to launch() instead.
iface.launch()