ProfessorCEO commited on
Commit
8613805
·
verified ·
1 Parent(s): 785d8df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
 
5
 
6
  app = FastAPI()
7
 
@@ -14,6 +15,15 @@ class CodeRequest(BaseModel):
14
  language: str = "python"
15
  max_tokens: int = 128
16
 
 
 
 
 
 
 
 
 
 
17
  @app.get("/")
18
  def root():
19
  return {"status": "DevOS AI is running"}
@@ -22,17 +32,25 @@ def root():
22
  def complete_code(request: CodeRequest):
23
  prompt = f"Continue the following {request.language} code:\n{request.code}"
24
  inputs = tokenizer(prompt, return_tensors="pt")
25
-
26
  with torch.no_grad():
27
- outputs = model.generate(
28
- **inputs,
29
- max_new_tokens=request.max_tokens,
30
- temperature=0.2,
31
- do_sample=True,
32
- pad_token_id=tokenizer.eos_token_id
33
- )
34
-
 
 
 
 
 
 
 
 
 
 
35
  generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
- suggestion = generated[len(prompt):]
37
-
38
- return {"suggestion": suggestion.strip()}
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
+ from typing import List, Optional
6
 
7
  app = FastAPI()
8
 
 
15
  language: str = "python"
16
  max_tokens: int = 128
17
 
18
+ class ChatMessage(BaseModel):
19
+ role: str
20
+ content: str
21
+
22
+ class ChatRequest(BaseModel):
23
+ messages: List[ChatMessage]
24
+ system: Optional[str] = ""
25
+ max_tokens: int = 1024
26
+
27
  @app.get("/")
28
  def root():
29
  return {"status": "DevOS AI is running"}
 
32
  def complete_code(request: CodeRequest):
33
  prompt = f"Continue the following {request.language} code:\n{request.code}"
34
  inputs = tokenizer(prompt, return_tensors="pt")
 
35
  with torch.no_grad():
36
+ outputs = model.generate(**inputs, max_new_tokens=request.max_tokens,
37
+ temperature=0.2, do_sample=True, pad_token_id=tokenizer.eos_token_id)
38
+ generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
+ return {"suggestion": generated[len(prompt):].strip()}
40
+
41
+ @app.post("/chat")
42
+ def chat(request: ChatRequest):
43
+ # Build conversation prompt
44
+ prompt = request.system + "\n\n" if request.system else ""
45
+ for msg in request.messages[-8:]: # last 8 messages for context
46
+ role = "User" if msg.role == "user" else "DevOS AI"
47
+ prompt += f"{role}: {msg.content}\n"
48
+ prompt += "DevOS AI:"
49
+
50
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
51
+ with torch.no_grad():
52
+ outputs = model.generate(**inputs, max_new_tokens=request.max_tokens,
53
+ temperature=0.4, do_sample=True, pad_token_id=tokenizer.eos_token_id)
54
  generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
+ reply = generated[len(prompt):].strip()
56
+ return {"reply": reply}