extraplus committed · verified
Commit dde69ed · 1 Parent(s): a868b3c

Create main.py
Files changed (1): main.py +65 -0
main.py ADDED
@@ -0,0 +1,65 @@
import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

app = FastAPI()

MODEL_ID = "AshokGakr/model-tiny"

print("Loading model...")

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(device)
model.eval()

print("Model loaded on", device)


def generate_stream(prompt):
    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # TextIteratorStreamer yields decoded text chunks as they are generated.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        streamer=streamer,
    )

    # Run generation in a background thread so this generator can
    # consume the streamer without blocking.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for new_text in streamer:
        yield new_text


@app.post("/chat")
async def chat(data: dict):
    system_prompt = data.get("system", "You are a helpful AI assistant.")
    history = data.get("history", "")
    message = data.get("message", "")

    full_prompt = f"{system_prompt}\n{history}\nUser: {message}\nAssistant:"

    # Stream the generated text back to the client as it is produced.
    return StreamingResponse(
        generate_stream(full_prompt),
        media_type="text/plain",
    )
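
As a usage sketch (not part of the commit): assuming the app is served locally with uvicorn on its default port 8000, the streaming /chat endpoint can be consumed chunk by chunk with the requests library. The host, port, and payload values below are illustrative.

import requests

# Hypothetical client for the /chat endpoint above; assumes the server
# was started with `uvicorn main:app` on the default port 8000.
resp = requests.post(
    "http://localhost:8000/chat",
    json={
        "system": "You are a helpful AI assistant.",
        "history": "",
        "message": "Hello!",
    },
    stream=True,  # keep the connection open and read the body incrementally
)

# Print each decoded chunk of the StreamingResponse as it arrives.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)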