Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import torch
|
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from pydantic import BaseModel
|
| 6 |
-
import uvicorn
|
| 7 |
|
| 8 |
# Load model and tokenizer
|
| 9 |
model_name = "roneneldan/TinyStories-33M"
|
|
@@ -29,12 +28,14 @@ def generate_story(prompt):
|
|
| 29 |
|
| 30 |
return generated_text
|
| 31 |
|
| 32 |
-
# Set up FastAPI
|
| 33 |
app = FastAPI()
|
| 34 |
|
|
|
|
| 35 |
class StoryRequest(BaseModel):
|
| 36 |
prompt: str
|
| 37 |
|
|
|
|
| 38 |
@app.post("/generate")
|
| 39 |
async def generate(request: StoryRequest):
|
| 40 |
generated_text = generate_story(request.prompt)
|
|
@@ -49,9 +50,10 @@ iface = gr.Interface(
|
|
| 49 |
description="Generate short stories using the TinyStories-33M model."
|
| 50 |
)
|
| 51 |
|
| 52 |
-
# Mount Gradio
|
| 53 |
app = gr.mount_gradio_app(app, iface, path="/")
|
| 54 |
|
| 55 |
-
#
|
| 56 |
if __name__ == "__main__":
|
|
|
|
| 57 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
from fastapi import FastAPI
|
| 5 |
from pydantic import BaseModel
|
|
|
|
| 6 |
|
| 7 |
# Load model and tokenizer
|
| 8 |
model_name = "roneneldan/TinyStories-33M"
|
|
|
|
| 28 |
|
| 29 |
return generated_text
|
| 30 |
|
| 31 |
+
# Set up FastAPI app
|
| 32 |
app = FastAPI()
|
| 33 |
|
| 34 |
+
# FastAPI request model
|
| 35 |
class StoryRequest(BaseModel):
|
| 36 |
prompt: str
|
| 37 |
|
| 38 |
+
# FastAPI route to generate story
|
| 39 |
@app.post("/generate")
|
| 40 |
async def generate(request: StoryRequest):
|
| 41 |
generated_text = generate_story(request.prompt)
|
|
|
|
| 50 |
description="Generate short stories using the TinyStories-33M model."
|
| 51 |
)
|
| 52 |
|
| 53 |
+
# Mount Gradio app to FastAPI
|
| 54 |
app = gr.mount_gradio_app(app, iface, path="/")
|
| 55 |
|
| 56 |
+
# For local testing
|
| 57 |
if __name__ == "__main__":
|
| 58 |
+
import uvicorn
|
| 59 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|