aryo100 committed
Commit f8184cb · 1 Parent(s): 46ce46c

update app

Files changed (1): app.py +27 -20
app.py CHANGED
@@ -1,28 +1,35 @@
 from fastapi import FastAPI
+from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import torch
+import torch, os
+import uvicorn
 
-app = FastAPI()
+# --- FastAPI initialization ---
+app = FastAPI(title="Qwen Chat API")
 
-# Load Qwen-7B with trust_remote_code
-model_name = "Qwen/Qwen-1_8B"
+# --- Load model & tokenizer ---
+model_name = "Qwen/Qwen-1_8B-Chat"  # change to whatever model fits in RAM
 tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    trust_remote_code=True,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto" if torch.cuda.is_available() else "cpu"
+)
 
-@app.get("/")
-def home():
-    return {"status": "ok", "message": "Qwen-7B API is running!"}
+# --- Request & Response schema ---
+class ChatRequest(BaseModel):
+    prompt: str
+    max_new_tokens: int = 128
 
 @app.post("/chat")
-async def chat(prompt: str):
-    inputs = tokenizer(prompt, return_tensors="pt")
-    with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=200,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9
-        )
-    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return {"response": text}
+def chat(req: ChatRequest):
+    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
+    outputs = model.generate(**inputs, max_new_tokens=req.max_new_tokens)
+    reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return {"reply": reply}
+
+# --- Entrypoint ---
+if __name__ == "__main__":
+    port = int(os.environ.get("PORT", 7860))  # HF Spaces default port
+    uvicorn.run("app:app", host="0.0.0.0", port=port)
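
With this change, /chat takes a validated JSON body (ChatRequest) instead of a bare query parameter. A minimal client sketch for exercising the updated endpoint, assuming the server is reachable at localhost on the Space's default port 7860; the host and example prompt are illustrative, not from the commit:

# client.py -- minimal sketch for calling the updated /chat endpoint.
# Assumes the API runs at http://localhost:7860 (the default port in app.py).
import requests

resp = requests.post(
    "http://localhost:7860/chat",
    json={"prompt": "Hello, who are you?", "max_new_tokens": 64},
    timeout=300,  # generation on CPU can be slow
)
resp.raise_for_status()
print(resp.json()["reply"])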