bsny committed on
Commit d6e6a43 · verified · 1 Parent(s): c41dab3

Update app.py

Files changed (1)
  app.py  +15 −25
app.py CHANGED
@@ -1,35 +1,25 @@
 from fastapi import FastAPI, Request
-from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+import torch

 app = FastAPI()

-generator = pipeline("text-generation", model="microsoft/phi-2", max_new_tokens=150)
-
-sessions = {}
-
-@app.get("/")
-def read_root():
-    return {"message": "LLM API running!"}
-
+# Load model and tokenizer
+model_id = "meta-llama/Llama-3.1-8B"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto"  # Will auto-detect if CUDA or CPU
+)
+
+pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)
+
 @app.post("/chat")
 async def chat(request: Request):
     data = await request.json()
-    session_id = data.get("session_id", "default")
-    user_input = data.get("message", "")
-    system_prompt = data.get("system_prompt", "You are a helpful assistant.")
-
-    if session_id not in sessions:
-        sessions[session_id] = system_prompt + "\n"
-
-    sessions[session_id] += f"User: {user_input}\nAssistant:"
-
-    output = generator(sessions[session_id], max_new_tokens=150)[0]["generated_text"]
-
-    if "Assistant:" in output:
-        assistant_response = output.split("Assistant:")[-1].strip()
-    else:
-        assistant_response = output.strip()
-
-    sessions[session_id] += f" {assistant_response}\n"
-
-    return {"response": assistant_response}
+    prompt = data.get("prompt", "")
+    output = pipe(prompt)[0]['generated_text']
+    return {"response": output}