Batrdj committed
Commit f7cfbba · verified · 1 Parent(s): 2f1d85a

Update app.py

Files changed (1)
1. app.py +31 -9
app.py CHANGED
@@ -1,30 +1,52 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 app = FastAPI()
 
-# Ultra-tiny model (SAFE for free CPU)
+# Ultra-tiny model (SAFE for free CPU)
 MODEL_NAME = "sshleifer/tiny-gpt2"
 
+# Load tokenizer & model once at startup
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    torch_dtype=torch.float32
+)
+model.eval()
 
+# Request schema
 class Prompt(BaseModel):
     message: str
 
+# Health check
 @app.get("/")
 def root():
     return {"status": "TinyLLM API is running"}
 
+# Chat endpoint
 @app.post("/chat")
 def chat(prompt: Prompt):
-    inputs = tokenizer(prompt.message, return_tensors="pt")
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=50,
-        do_sample=True,
-        temperature=0.7
+    inputs = tokenizer(
+        prompt.message,
+        return_tensors="pt",
+        truncation=True,
+        max_length=128
     )
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=50,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9
+        )
+
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"response": response}
+
+    return {
+        "input": prompt.message,
+        "response": response
+    }
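
For a quick local smoke test of the updated endpoints, a minimal sketch using FastAPI's TestClient is shown below. It is not part of this commit: it assumes the file above is saved as app.py in the working directory and that fastapi, httpx, transformers, and torch are installed; the generated text will vary between runs because sampling is enabled.

# smoke_test.py - hypothetical local test, not included in this commit
from fastapi.testclient import TestClient

from app import app  # importing app.py loads sshleifer/tiny-gpt2 once at startup

client = TestClient(app)

# Health check: GET /
health = client.get("/")
print(health.json())  # expected: {"status": "TinyLLM API is running"}

# Chat endpoint: POST /chat with the Prompt schema {"message": ...}
reply = client.post("/chat", json={"message": "Hello, tiny model!"})
print(reply.status_code)          # expected: 200
print(reply.json()["input"])      # echoes the prompt back
print(reply.json()["response"])   # sampled continuation; differs run to run

Because the endpoint uses do_sample=True, responses are non-deterministic; for reproducible checks one could set do_sample=False or seed torch before calling generate.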