LucianStorm commited on
Commit
c05461e
·
verified ·
1 Parent(s): 005eafc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -11
app.py CHANGED
@@ -1,16 +1,60 @@
1
- from fastapi import FastAPI
2
- import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- app = FastAPI()
 
6
 
7
- MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
8
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
10
 
11
- @app.get("/generate")
12
- def generate(prompt: str):
13
- inputs = tokenizer(prompt, return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
14
- output = model.generate(**inputs, max_length=200)
15
- return {"response": tokenizer.decode(output[0], skip_special_tokens=True)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+ from typing import List
6
+ import uvicorn
7
+
8
+ app = FastAPI(title="TinyLlama Fitness Bot")
9
+
10
+ # Initialize model and tokenizer
11
+ model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ model = AutoModelForCausalLM.from_pretrained(
14
+ model_name,
15
+ torch_dtype=torch.float32, # Use float32 for CPU
16
+ low_cpu_mem_usage=True
17
+ )
18
+
19
+ class Query(BaseModel):
20
+ prompt: str
21
+ max_length: int = 256
22
+ temperature: float = 0.7
23
 
24
+ class Response(BaseModel):
25
+ response: str
26
 
27
+ @app.get("/")
28
+ def read_root():
29
+ return {"message": "TinyLlama Fitness Bot API is running!"}
30
 
31
+ @app.post("/chat", response_model=Response)
32
+ async def chat(query: Query):
33
+ try:
34
+ # Format prompt for TinyLlama
35
+ system_prompt = """You are a knowledgeable fitness and nutrition assistant.
36
+ Provide helpful, science-based advice about workouts, nutrition, and healthy lifestyle choices."""
37
+
38
+ formatted_prompt = f"<|system|>{system_prompt}</s><|user|>{query.prompt}</s><|assistant|>"
39
+
40
+ inputs = tokenizer(formatted_prompt, return_tensors="pt")
41
+
42
+ with torch.no_grad():
43
+ outputs = model.generate(
44
+ inputs["input_ids"],
45
+ max_new_tokens=query.max_length,
46
+ temperature=query.temperature,
47
+ top_p=0.9,
48
+ do_sample=True,
49
+ )
50
+
51
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ response = response.split("<|assistant|>")[-1].strip()
53
+
54
+ return Response(response=response)
55
+
56
+ except Exception as e:
57
+ raise HTTPException(status_code=500, detail=str(e))
58
 
59
+ if __name__ == "__main__":
60
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)