dedlepexa committed on
Commit
34224fe
·
verified ·
1 Parent(s): 81b2953

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -21
app.py CHANGED
@@ -1,6 +1,5 @@
1
- from fastapi import FastAPI
2
  from fastapi.responses import PlainTextResponse
3
- from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  import torch
6
 
@@ -13,20 +12,8 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
13
 
14
  model.eval()
15
 
16
- class Message(BaseModel):
17
- message: str
18
 
19
- @app.get("/")
20
- async def root():
21
- return {"status": "TinyLlama работает"}
22
-
23
- from fastapi import Request
24
-
25
- @app.post("/")
26
- async def receive(request: Request):
27
-
28
- data = await request.json()
29
- message = data.get("message", "")
30
 
31
  prompt = f"User: {message}\nAssistant:"
32
 
@@ -35,7 +22,7 @@ async def receive(request: Request):
35
  with torch.no_grad():
36
  outputs = model.generate(
37
  **inputs,
38
- max_new_tokens=20,
39
  do_sample=False
40
  )
41
 
@@ -44,14 +31,37 @@ async def receive(request: Request):
44
  if "Assistant:" in reply:
45
  reply = reply.split("Assistant:")[-1].strip()
46
 
47
- return PlainTextResponse(reply)
48
 
49
- reply = tokenizer.decode(outputs[0], skip_special_tokens=True)
50
 
51
- if "Assistant:" in reply:
52
- reply = reply.split("Assistant:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- return PlainTextResponse(reply)
 
55
 
56
 
57
  if __name__ == "__main__":
 
1
+ from fastapi import FastAPI, Request
2
  from fastapi.responses import PlainTextResponse
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
 
12
 
13
  model.eval()
14
 
 
 
15
 
16
+ def generate_ai(message: str):
 
 
 
 
 
 
 
 
 
 
17
 
18
  prompt = f"User: {message}\nAssistant:"
19
 
 
22
  with torch.no_grad():
23
  outputs = model.generate(
24
  **inputs,
25
+ max_new_tokens=25,
26
  do_sample=False
27
  )
28
 
 
31
  if "Assistant:" in reply:
32
  reply = reply.split("Assistant:")[-1].strip()
33
 
34
+ return reply
35
 
 
36
 
37
+ @app.get("/")
38
+ async def root():
39
+ return PlainTextResponse("AI server работает")
40
+
41
+
42
+ @app.api_route("/", methods=["GET","POST","PUT","PATCH","DELETE","HEAD"])
43
+ async def universal(request: Request):
44
+
45
+ try:
46
+
47
+ # пробуем получить JSON
48
+ try:
49
+ data = await request.json()
50
+ message = data.get("message", "")
51
+ except:
52
+ # если не JSON — читаем обычный текст
53
+ body = await request.body()
54
+ message = body.decode("utf-8")
55
+
56
+ if not message:
57
+ message = "Hello"
58
+
59
+ reply = generate_ai(message)
60
+
61
+ return PlainTextResponse(reply)
62
 
63
+ except Exception as e:
64
+ return PlainTextResponse(f"ERROR: {str(e)}")
65
 
66
 
67
  if __name__ == "__main__":