Akshitha1 committed on
Commit
69a52fc
·
verified ·
1 Parent(s): 10f7d8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -25
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import torch
2
  import torch.nn as nn
 
3
  import pandas as pd
4
- from torch.utils.data import Dataset
 
5
  from sklearn.model_selection import train_test_split
6
- from fastapi import FastAPI
7
- from pydantic import BaseModel
8
- from fastapi.responses import JSONResponse
9
  import os
10
 
11
  # Load data
@@ -42,7 +41,7 @@ train_data, test_data = train_test_split(df, test_size=0.2, random_state=42)
42
  tokenizer = ScratchTokenizer()
43
  tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())
44
 
45
- # Dataset Class (not used in inference but useful for training)
46
  class TextDataset(Dataset):
47
  def __init__(self, data, tokenizer, max_len=200):
48
  self.data = data
@@ -93,31 +92,49 @@ def load_model(model, path="gpt_model.pth"):
93
  load_model(model)
94
 
95
  # Generate Response
 
 
 
 
 
 
 
 
 
 
 
 
96
  def generate_response(model, query, max_length=200):
97
  model.eval()
98
- src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
99
- tgt = torch.tensor([[1]]).to(device) # <SOS>
100
- for _ in range(max_length):
101
- output = model(src, tgt)
102
- next_word = output.argmax(-1)[:, -1].unsqueeze(1)
103
- tgt = torch.cat([tgt, next_word], dim=1)
104
- if next_word.item() == 2: # <EOS>
105
- break
 
 
 
106
  return tokenizer.decode(tgt.squeeze(0).tolist())
107
 
108
- # FastAPI app
109
- app = FastAPI()
110
 
111
- class Query(BaseModel):
112
- query: str
113
 
114
- @app.get("/")
115
- async def root():
116
  return {"message": "Transformer-based Response Generator API is running!"}
117
 
118
- @app.post("/query")
119
- async def query_model(query: Query):
120
- if not query.query.strip():
121
- return JSONResponse(status_code=400, content={"error": "Query cannot be empty"})
122
- response = generate_response(model, query.query)
123
- return {"query": query.query, "response": response}
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.optim as optim
4
  import pandas as pd
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from flask import Flask, request, jsonify
7
  from sklearn.model_selection import train_test_split
 
 
 
8
  import os
9
 
10
  # Load data
 
41
  tokenizer = ScratchTokenizer()
42
  tokenizer.build_vocab(train_data["instruction"].tolist() + train_data["response"].tolist())
43
 
44
+ # Dataset Class
45
  class TextDataset(Dataset):
46
  def __init__(self, data, tokenizer, max_len=200):
47
  self.data = data
 
92
  load_model(model)
93
 
94
  # Generate Response
95
+ # def generate_response(model, query, max_length=200):
96
+ # model.eval()
97
+ # src = torch.tensor(tokenizer.encode(query)).unsqueeze(0).to(device)
98
+ # tgt = torch.tensor([[1]]).to(device) # <SOS>
99
+ # for _ in range(max_length):
100
+ # output = model(src, tgt)
101
+ # next_word = output.argmax(-1)[:, -1].unsqueeze(1)
102
+ # tgt = torch.cat([tgt, next_word], dim=1)
103
+ # if next_word.item() == 2: # <EOS>
104
+ # break
105
+ # return tokenizer.decode(tgt.squeeze(0).tolist())
106
+
def generate_response(model, query, max_length=200):
    """Greedily decode a response to *query* using *model*.

    The query is encoded with the module-level ``tokenizer`` and moved to the
    module-level ``device``. Decoding starts from token id 1 (<SOS>) and, at
    each step, appends the highest-scoring token at the last position. It
    stops after *max_length* steps or as soon as token id 2 (<EOS>) is
    produced, whichever comes first.

    Args:
        model: Callable as ``model(src, tgt)``; put into eval mode here.
        query: Text to respond to; passed to ``tokenizer.encode``.
        max_length: Maximum number of tokens to generate.

    Returns:
        The result of ``tokenizer.decode`` on the generated token ids
        (including the leading <SOS> id).
    """
    sos_id = 1
    eos_id = 2

    model.eval()
    # Inference only: no autograd graph is needed for the decode loop.
    with torch.no_grad():
        encoded_query = tokenizer.encode(query)
        src = torch.tensor(encoded_query).unsqueeze(0).to(device)
        tgt = torch.tensor([[sos_id]]).to(device)

        for _ in range(max_length):
            scores = model(src, tgt)
            last_position = scores[:, -1, :]
            next_token = last_position.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)
            # .item() relies on the batch holding exactly one sequence.
            if next_token.item() == eos_id:
                break

    generated_ids = tgt.squeeze(0).tolist()
    return tokenizer.decode(generated_ids)
121
 
 
 
122
 
123
+ # Flask App
124
+ app = Flask(__name__)
125
 
@app.route("/")
def home():
    """Health-check route: return a fixed status message for the root URL."""
    status = {"message": "Transformer-based Response Generator API is running!"}
    return status
129
 
@app.route("/query", methods=["POST"])
def query_model():
    """Handle POST /query: generate a model response for the supplied query.

    Expects a JSON object body with a string ``"query"`` field.

    Returns:
        On success, a JSON object with the original ``query`` and the
        generated ``response``. On invalid input, a JSON ``error`` object
        with HTTP status 400.
    """
    # silent=True yields None for a missing, non-JSON, or malformed body
    # instead of raising, so every bad request gets the same JSON 400 shape.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Guards the .get() below: a None, list, or scalar body would
        # otherwise raise AttributeError and surface as a 500.
        return jsonify({"error": "Request body must be a JSON object"}), 400

    query = payload.get("query", "")
    if query and not isinstance(query, str):
        # A truthy non-string (number, list, ...) must not reach the tokenizer.
        return jsonify({"error": "Query must be a string"}), 400
    if not query or not query.strip():
        # Whitespace-only input counts as empty.
        return jsonify({"error": "Query cannot be empty"}), 400

    response = generate_response(model, query)
    return jsonify({"query": query, "response": response})
138
+
139
+ # DO NOT ADD app.run()
140
+