Akshitha1 committed
Commit 9736d99 · verified · 1 Parent(s): 60fbd3c

Update app.py

Files changed (1)
  1. app.py +25 -41
app.py CHANGED
@@ -1,10 +1,11 @@
 import torch
 import torch.nn as nn
-import torch.optim as optim
 import pandas as pd
-from torch.utils.data import Dataset, DataLoader
-from flask import Flask, request, jsonify
+from torch.utils.data import Dataset
 from sklearn.model_selection import train_test_split
+from fastapi import FastAPI
+from pydantic import BaseModel
+from fastapi.responses import JSONResponse
 import os

 # Load data
@@ -41,7 +42,7 @@ train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)
 tokenizer = ScratchTokenizer()
 tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())

-# Dataset Class
+# Dataset Class (not used in inference but useful for training)
 class TextDataset(Dataset):
     def __init__(self, data, tokenizer, max_len=200):
         self.data = data
@@ -92,49 +93,32 @@ def load_model(model, path="gpt_model.pth"):
 load_model(model)

 # Generate Response
-# def generate_response(model, query, max_length=200):
-#     model.eval()
-#     src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
-#     tgt = torch.tensor([[1]]).to(device)  # <SOS>
-#     for _ in range(max_length):
-#         output = model(src, tgt)
-#         next_word = output.argmax(-1)[:, -1].unsqueeze(1)
-#         tgt = torch.cat([tgt, next_word], dim=1)
-#         if next_word.item() == 2:  # <EOS>
-#             break
-#     return tokenizer.decode(tgt.squeeze(0).tolist())
-
 def generate_response(model, query, max_length=200):
     model.eval()
-    with torch.no_grad():  # Disable gradient tracking
-        src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
-        tgt = torch.tensor([[1]]).to(device)  # <SOS>
-
-        for _ in range(max_length):
-            output = model(src, tgt)
-            next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)
-            tgt = torch.cat([tgt, next_token], dim=1)
-            if next_token.item() == 2:  # <EOS>
-                break
-
+    src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
+    tgt = torch.tensor([[1]]).to(device)  # <SOS>
+    for _ in range(max_length):
+        output = model(src, tgt)
+        next_word = output.argmax(-1)[:, -1].unsqueeze(1)
+        tgt = torch.cat([tgt, next_word], dim=1)
+        if next_word.item() == 2:  # <EOS>
+            break
     return tokenizer.decode(tgt.squeeze(0).tolist())

+# FastAPI app
+app = FastAPI()

-# Flask App
-app = Flask(__name__)
+class Query(BaseModel):
+    query: str

-@app.route("/")
-def home():
+@app.get("/")
+async def root():
     return {"message": "Transformer-based Response Generator API is running!"}

-@app.route("/query", methods=["POST"])
-def query_model():
-    data = request.get_json()
-    query = data.get("query", "")
-    if not query:
-        return jsonify({"error": "Query cannot be empty"}), 400
-    response = generate_response(model, query)
-    return jsonify({"query": query, "response": response})
-
-# DO NOT ADD app.run()
+@app.post("/query")
+async def query_model(query: Query):
+    if not query.query.strip():
+        return JSONResponse(status_code=400, content={"error": "Query cannot be empty"})
+    response = generate_response(model, query.query)
+    return {"query": query.query, "response": response}
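For reference, a minimal sketch of how the converted endpoints could be exercised after this change, using FastAPI's TestClient. It is illustrative only (the example query string is made up) and assumes app.py imports cleanly, i.e. the data file it loads, the ScratchTokenizer vocabulary, and the gpt_model.pth weights are all available.

# Illustrative only -- not part of this commit.
# Exercises the new FastAPI routes in-process with FastAPI's TestClient.
from fastapi.testclient import TestClient

from app import app  # the FastAPI instance defined above

client = TestClient(app)

# Health check on the root route
resp = client.get("/")
print(resp.json())  # {"message": "Transformer-based Response Generator API is running!"}

# POST a query; the request body must match the Query pydantic model
resp = client.post("/query", json={"query": "How do I track my order?"})
print(resp.status_code, resp.json())

# A whitespace-only query hits the 400 JSONResponse branch
resp = client.post("/query", json={"query": "   "})
print(resp.status_code, resp.json())  # 400 {"error": "Query cannot be empty"}

Since the file defines no __main__ block or server start-up call (the removed Flask version explicitly noted not to add app.run()), the app is expected to be served by the hosting environment, for example with uvicorn app:app.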
 
 
 
 
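One detail of the rewritten generate_response: it decodes greedily without the torch.no_grad() context that the previous version wrapped around the loop, so autograd will track each forward pass during inference. Below is a sketch of the same loop with gradient tracking disabled, as the removed version did; it is illustrative, not part of this commit, and relies on the same module-level model, tokenizer, and device as app.py.

# Illustrative variant -- not part of this commit.
# Same greedy decoding as generate_response above, but wrapped in torch.no_grad()
# (as the removed Flask-era version was) so no autograd graph is built at inference time.
import torch

def generate_response_no_grad(model, query, max_length=200):
    model.eval()
    with torch.no_grad():
        src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
        tgt = torch.tensor([[1]]).to(device)  # <SOS>
        for _ in range(max_length):
            output = model(src, tgt)
            next_word = output.argmax(-1)[:, -1].unsqueeze(1)  # greedy pick of the next token
            tgt = torch.cat([tgt, next_word], dim=1)
            if next_word.item() == 2:  # <EOS>
                break
    return tokenizer.decode(tgt.squeeze(0).tolist())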