GTee2 committed on
Commit a909106 (verified)
1 Parent(s): 46511dd

Create app.py

Files changed (1)
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
+ from fastapi import FastAPI, Request
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+ from collections import defaultdict
+ import torch
+
+ app = FastAPI(title="Mariza Koller 1.5B - API with Memory 😈")
+
+ print("🔥 Loading Qwen2-1.5B-Instruct in int8 on CPU... (hang tight, 2-3 min the first time)")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+     "Qwen/Qwen2-1.5B-Instruct",
+     device_map="cpu",
+     load_in_8bit=True,
+     torch_dtype=torch.float16,
+     trust_remote_code=True
+ )
+
+ # In-memory conversation cache: {user_id: list of messages}
+ history_db = defaultdict(list)
+ MAX_CONTEXT_TOKENS = 3500
+
+ @app.get("/")
+ async def root():
+     return {"message": "Mariza 1.5B is alive and running hot on the CPU, boss! 😏 send a POST to /chat"}
+
+ @app.post("/chat")
+ async def chat(request: Request):
+     data = await request.json()
+     prompt = data.get("prompt", "").strip()
+     user_id = str(data.get("user_id", "default"))
+     max_tokens = data.get("max_tokens", 512)
+     temperature = data.get("temperature", 0.7)
+     stream = data.get("stream", False)
+
+     if not prompt:
+         return JSONResponse({"error": "empty prompt, you rascal"})
+
+     # Build the history in Qwen2's chat format
+     messages = history_db[user_id]
+     full_prompt = ""
+     for role, content in messages:
+         full_prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
+     full_prompt += f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+     inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=4096)
+
+     if stream:
+         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+         generation_kwargs = {
+             "input_ids": inputs.input_ids,
+             "attention_mask": inputs.attention_mask,
+             "max_new_tokens": max_tokens,
+             "temperature": temperature,
+             "do_sample": True,
+             "top_p": 0.9,
+             "repetition_penalty": 1.1,
+             "streamer": streamer
+         }
+         thread = Thread(target=model.generate, kwargs=generation_kwargs)
+         thread.start()
+
+         def generate():
+             for new_text in streamer:
+                 yield new_text
+         return StreamingResponse(generate(), media_type="text/event-stream")
+
+     else:
+         outputs = model.generate(
+             input_ids=inputs.input_ids,
+             attention_mask=inputs.attention_mask,
+             max_new_tokens=max_tokens,
+             temperature=temperature,
+             do_sample=True,
+             top_p=0.9,
+             repetition_penalty=1.1
+         )
+         gen_tokens = outputs[0][inputs.input_ids.shape[1]:]  # keep only the newly generated tokens
+         resposta = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()
+
+         # Save to history
+         messages.append(("user", prompt))
+         messages.append(("assistant", resposta))
+
+         # Drop old history once it exceeds the limit
+         while sum(len(tokenizer.encode(m[1])) for m in messages) > MAX_CONTEXT_TOKENS:
+             messages.pop(0)
+
+         return JSONResponse({"response": resposta, "user_id": user_id})
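
For reference, a minimal client call against the /chat endpoint could look like the sketch below. The host and port (localhost:7860), the example prompt, and the user_id value are assumptions for illustration, not part of this commit.

    import requests

    # Hypothetical client sketch: host/port are assumptions; point this at wherever the app is served.
    resp = requests.post(
        "http://localhost:7860/chat",
        json={
            "prompt": "Hi, how are you?",
            "user_id": "user-123",   # any stable id keeps a separate per-user history
            "max_tokens": 256,
            "temperature": 0.7,
            "stream": False,         # True would return a text/event-stream response instead
        },
    )
    print(resp.json()["response"])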