Update app.py
Browse files
app.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
from
|
| 2 |
-
import
|
| 3 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
| 4 |
|
| 5 |
-
app =
|
| 6 |
|
| 7 |
# Load the LSTM-based language model
|
| 8 |
-
model_path = "
|
| 9 |
-
tokenizer = GPT2Tokenizer.from_pretrained("
|
| 10 |
-
model = GPT2LMHeadModel.from_pretrained("
|
| 11 |
model.load_state_dict(torch.load(model_path))
|
| 12 |
|
| 13 |
# Set the model to evaluation mode
|
|
@@ -20,16 +20,54 @@ def generate_text(prompt):
|
|
| 20 |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 21 |
return generated_text
|
| 22 |
|
| 23 |
-
@app.
|
| 24 |
-
def home():
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
if __name__ == "__main__":
|
| 35 |
-
|
|
|
|
|
|
|
|
|
import torch  # FIX: torch.load() is used below but torch was never imported (NameError at startup)
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from transformers import GPT2LMHeadModel, GPT2Tokenizer

app = FastAPI()

# Load the fine-tuned GPT-2 language model.
# NOTE(review): the original comment said "LSTM-based", but the classes used
# below (GPT2Tokenizer / GPT2LMHeadModel) are GPT-2, a transformer.
model_path = "your_model.pth"  # Replace with your model path
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Overwrite the pretrained weights with the fine-tuned checkpoint on disk.
model.load_state_dict(torch.load(model_path))

# Set the model to evaluation mode
|
|
|
|
| 20 |
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 21 |
return generated_text
|
| 22 |
|
| 23 |
+
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the single-page UI for the text-generation demo.

    Returns a static HTML page with a prompt form; the inline script posts
    the prompt as JSON to POST /generate and renders the "generated_text"
    field of the JSON response into the #output div.
    `request` is injected by FastAPI but not read here.
    """
    # NOTE(review): HTML reproduced as in the source; exact in-string
    # indentation could not be recovered from the diff view (cosmetic only).
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>LSTM Text Generation</title>
    </head>
    <body>
        <h1>LSTM Text Generation</h1>
        <form id="text-form">
            <label for="user-input">Enter your input:</label><br>
            <textarea id="user-input" name="user-input" rows="4" cols="50"></textarea><br>
            <button type="submit">Generate Text</button>
        </form>
        <div id="output"></div>

        <script>
            document.getElementById("text-form").addEventListener("submit", function(event) {
                event.preventDefault();
                var userInput = document.getElementById("user-input").value;

                fetch("/generate", {
                    method: "POST",
                    headers: {
                        "Content-Type": "application/json"
                    },
                    body: JSON.stringify({ input_text: userInput })
                })
                .then(response => response.json())
                .then(data => {
                    document.getElementById("output").innerText = data.generated_text;
                });
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content, status_code=200)
|
| 64 |
+
|
| 65 |
+
@app.post("/generate")
async def generate(request: Request):
    """Generate text for the prompt posted by the frontend.

    The HTML page served by `home` posts JSON (Content-Type
    application/json, body {"input_text": ...}). The previous signature,
    `input_text: str = Form(...)`, only accepted form-encoded bodies, so
    every request from the bundled frontend failed with HTTP 422.
    Accept JSON first and fall back to form data so existing
    form-encoded clients keep working.

    Returns:
        dict: {"generated_text": <model output string>}
    """
    content_type = request.headers.get("content-type", "")
    if content_type.startswith("application/json"):
        payload = await request.json()
        input_text = payload.get("input_text", "")
    else:
        # Backward-compatible path for form-encoded submissions.
        form = await request.form()
        input_text = form.get("input_text", "")
    generated_text = generate_text(input_text)
    return {"generated_text": generated_text}
|
| 69 |
|
| 70 |
if __name__ == "__main__":
    # Run the ASGI app directly for local development. uvicorn is imported
    # lazily so merely importing this module (e.g. under another ASGI
    # server) does not require it at module load time.
    import uvicorn
    # Bind on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 73 |
+
|