kouki321 committed
Commit b4e99db · verified · 1 Parent(s): d5516d0

Update app.py

Files changed (1):
  1. app.py +31 -9
app.py CHANGED
@@ -1,22 +1,44 @@
- #from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
  from fastapi import FastAPI, Request
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch

  app = FastAPI()
- model_id = "google/flan-t5-small"
+
+ model_id = "google/flan-t5-small"  # Replace with your model here
+

  #"unsloth/mistral-7b-v0.2-bnb-4bit"
  #deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

-
  tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
- generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.float16,
+     device_map="auto",
+ )
+
+ cache = {}

  @app.post("/generate")
  async def generate(request: Request):
      data = await request.json()
-     prompt = data.get("prompt", "")
-     result = generator(prompt, max_new_tokens=100)[0]["generated_text"]
-     return {"output": result}
+     prompt = data.get("prompt", "").strip()
+
+     if prompt in cache:
+         return {"output": cache[prompt], "cached": True}
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+     outputs = model.generate(
+         inputs.input_ids,
+         max_new_tokens=100,
+         use_cache=True,
+         do_sample=True,
+         top_p=0.95,
+         top_k=50,
+         temperature=0.7,
+     )
+
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     cache[prompt] = generated_text

+     return {"output": generated_text, "cached": False}
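
For reference, a minimal client sketch for exercising the updated /generate endpoint once the app is running (assumed served locally, e.g. uvicorn app:app --port 8000; the host, port, prompt text, and use of the requests library are illustrative assumptions, not part of the commit). Note that google/flan-t5-small is a seq2seq (T5) checkpoint that AutoModelForCausalLM will refuse to load, so the placeholder would need to be swapped for a causal model id, such as one of the commented-out ones, before the committed code runs end to end.

# Hypothetical client for the /generate endpoint above; URL and prompt are examples only.
import requests

payload = {"prompt": "Explain what the in-memory cache in this API does."}

# First call: the prompt is not cached yet, so the model generates a response.
first = requests.post("http://localhost:8000/generate", json=payload).json()
print(first)   # expected shape: {"output": "...", "cached": False}

# Second call with the same prompt: served from the module-level dict cache.
second = requests.post("http://localhost:8000/generate", json=payload).json()
print(second)  # expected shape: {"output": "...", "cached": True}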