kouki321 committed on
Commit
93e7ae8
·
verified ·
1 Parent(s): 7ddc2f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -32
app.py CHANGED
@@ -1,38 +1,24 @@
1
- from fastapi import FastAPI, Request
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
- import torch
4
-
5
- app = FastAPI()
6
-
7
- model_id = "google/flan-t5-small" # Replace with your model here
8
-
9
-
10
  #"unsloth/mistral-7b-v0.2-bnb-4bit"
11
  #deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
12
 
13
- tokenizer = AutoTokenizer.from_pretrained(model_id)
14
- model = AutoModelForCausalLM.from_pretrained(
15
- model_id,
16
- torch_dtype=torch.float16,
17
- device_map="auto",
18
- )
19
-
20
-
21
- @app.post("/generate")
22
- async def generate(request: Request):
23
- data = await request.json()
24
- prompt = data.get("prompt", "").strip()
25
 
26
-
27
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
28
- outputs = model.generate(
29
- inputs.input_ids,
30
- max_new_tokens=100,
31
- use_cache=True,
32
-
33
- temperature=0.7,
34
- )
35
 
36
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
37
 
38
- return {"output": generated_text}
 
 
 
 
1
# Model setup: a small seq2seq model loaded with a project-local HF cache dir.
import os

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Example: using Google Flan-T5 small
model_id = "google/flan-t5-small"

# Alternative model ids kept for reference:
#"unsloth/mistral-7b-v0.2-bnb-4bit"
#deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

# Ensure cache directories exist before any download writes into them.
os.makedirs("/app/cache", exist_ok=True)

# Load tokenizer & model using the custom cache path
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="/app/cache")
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, cache_dir="/app/cache")
 
 
 
 
 
 
14
 
15
# Example simple inference
def generate(text):
    """Run *text* through the module-level tokenizer/model and return the decoded output."""
    encoded = tokenizer(text, return_tensors="pt")
    generated = model.generate(**encoded)
    # Decode only the first (and only) sequence in the batch.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
20
 
21
if __name__ == "__main__":
    # Simple CLI smoke test: run one hard-coded prompt and print the result.
    demo_prompt = "Translate English to French: Hello, how are you?"
    print(generate(demo_prompt))