{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9699ec04-d6f5-4bba-ac94-f33081929626",
   "metadata": {},
   "source": [
    "# IPL match winner prediction\n",
    "\n",
    "Loads the `pdp19/ipl_match_winner_predictor` sequence-classification checkpoint with `transformers`, runs it on a free-text query, and maps the predicted class index to a team name via `index_to_label`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "afa7acd8-8d00-4c86-b98c-4a540313149d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3d1d289d-9156-4f08-8f8f-37c101caf476",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"pdp19/ipl_match_winner_predictor\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "099f8842-46ad-440d-bdfe-efa5d79dbef3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ce56f89d-4a4e-4ebf-be68-67fedc0de210",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Do not change these labels\n",
    "index_to_label = {0: 'Sunrisers Hyderabad (SRH)',\n",
    "                  1: 'Chennai Super Kings (CSK)',\n",
    "                  2: 'Delhi Capitals (DC)',\n",
    "                  3: 'Mumbai Indians (MI)',\n",
    "                  4: 'Kings XI Punjab (KXIP)',\n",
    "                  5: 'Royal Challengers Bangalore (RCB)',\n",
    "                  6: 'Gujarat Titans (GT)',\n",
    "                  7: 'Kolkata Knight Riders (KKR)',\n",
    "                  8: 'Rajasthan Royals (RR)',\n",
    "                  9: 'Lucknow Super Giants (LSG)'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8ef7a4c6-2463-4397-8a63-dd28a1db1eb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name)\n",
    "\n",
    "# get_winner() moves the tokenized inputs to `device`, so the model weights must\n",
    "# live on the same device; otherwise the forward pass fails on a CUDA machine.\n",
    "# eval() is called explicitly so dropout stays disabled for inference.\n",
    "model = model.to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "93ecd2be-220d-4a7d-b15f-d6c80209a46c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_winner(input_query):\n",
    "    \"\"\"Predict the IPL match winner for one query or for a batch of queries.\n",
    "\n",
    "    Uses the module-level `tokenizer`, `model`, `device` and `index_to_label`.\n",
    "\n",
    "    Args:\n",
    "        input_query: A question string, or a list of question strings.\n",
    "\n",
    "    Returns:\n",
    "        For a string: one sentence naming the predicted team.\n",
    "        For a list: a list with one such sentence per query, in input order.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(input_query, padding=True, truncation=True, return_tensors='pt').to(device)\n",
    "\n",
    "    # Inference only, so skip autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "\n",
    "    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu().numpy()\n",
    "\n",
    "    # One argmax per row: np.argmax without an axis flattens the array, which is\n",
    "    # only correct when exactly one query was passed.\n",
    "    predicted_indices = np.argmax(probabilities, axis=-1)\n",
    "    messages = [\n",
    "        \"Winner of the match for your input query is: \" + index_to_label[int(index)]\n",
    "        for index in predicted_indices\n",
    "    ]\n",
    "\n",
    "    # A single string keeps the original return type (one sentence).\n",
    "    return messages[0] if isinstance(input_query, str) else messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "b4d1b206-c59b-44ed-9dd9-be44c52dd5d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"who is going to win if match is played beteen KKR and DC\"\n",
    "results = get_winner(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "391cb5f2-25a9-4db4-be02-05ef1093f633",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Winner of the match for your input query is: Delhi Capitals (DC)'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}