moamen270 committed on
Commit
be588c3
·
1 Parent(s): b4ed3fe

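Update endpoints.py: enable the text-generation pipeline and the /ask_llm endpoint; comment out the /v1/chat/completions endpoint, its request/response models, and the duplicate fastapi/transformers imports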

Browse files
Files changed (1) hide show
  1. endpoints.py +43 -43
endpoints.py CHANGED
@@ -26,8 +26,8 @@ app.add_middleware(
26
  # response = requests.post(API_URL, headers=headers, json=payload)
27
  # return response.json()
28
 
29
- from fastapi import FastAPI, HTTPException, Body
30
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
31
 
32
 
33
  # model = GPT2LMHeadModel.from_pretrained("EleutherAI/gpt-neo-2.7B")
@@ -37,46 +37,46 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
37
  tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardCoder-1B-V1.0")
38
  model = AutoModelForCausalLM.from_pretrained("WizardLM/WizardCoder-1B-V1.0")
39
 
40
- # pipe = pipeline("text-generation",
41
- # model=base_model,
42
- # tokenizer=tokenizer,
43
- # max_length=4000,
44
- # do_sample=True,
45
- # top_p=0.95,
46
- # repetition_penalty=1.2,
47
- # )
48
- # hf_llm = HuggingFacePipeline(pipeline=pipe)
49
 
50
 
51
- class ChatRequest(BaseModel):
52
- messages: list
53
- temperature: float = 1.0
54
- max_tokens: int = 50
55
- stream: bool = False
56
 
57
- class ChatResponse(BaseModel):
58
- response: str
59
 
60
- @app.post("/v1/chat/completions", response_model=ChatResponse)
61
- async def chat_completions(request: ChatRequest):
62
- try:
63
- # Prepare input prompt
64
- input_prompt = ""
65
- for message in request.messages:
66
- role = message.get('role', 'user')
67
- content = message.get('content', '')
68
- input_prompt += f"{role}: {content}\n"
69
 
70
- # Tokenize and generate response
71
- input_ids = tokenizer.encode(input_prompt, return_tensors='pt')
72
- output = model.generate(input_ids, max_length=1024, temperature=request.temperature, max_tokens=request.max_tokens)
73
 
74
- # Decode and send response
75
- response = tokenizer.decode(output[0], skip_special_tokens=True)
76
- return {"response": response}
77
 
78
- except Exception as e:
79
- raise HTTPException(status_code=500, detail=str(e))
80
 
81
 
82
  @app.get("/")
@@ -89,15 +89,15 @@ def root():
89
  # return {"message": result}
90
 
91
 
92
- # async def askLLM(prompt):
93
- # output = pipe(prompt,do_sample=False)
94
- # return output
95
 
96
- # @app.post("/ask_llm")
97
- # async def ask_llm_endpoint(prompt: str):
98
- # # result = await askLLM(prompt)
99
- # result = pipe(prompt,do_sample=False)
100
- # return {"result": result}
101
 
102
 
103
  # @app.post("/ask_HFAPI")
 
26
  # response = requests.post(API_URL, headers=headers, json=payload)
27
  # return response.json()
28
 
29
+ # from fastapi import FastAPI, HTTPException, Body
30
+ # from transformers import GPT2LMHeadModel, GPT2Tokenizer
31
 
32
 
33
  # model = GPT2LMHeadModel.from_pretrained("EleutherAI/gpt-neo-2.7B")
 
37
  tokenizer = AutoTokenizer.from_pretrained("WizardLM/WizardCoder-1B-V1.0")
38
  model = AutoModelForCausalLM.from_pretrained("WizardLM/WizardCoder-1B-V1.0")
39
 
40
+ pipe = pipeline("text-generation",
41
+ model=base_model,
42
+ tokenizer=tokenizer,
43
+ max_length=4000,
44
+ do_sample=True,
45
+ top_p=0.95,
46
+ repetition_penalty=1.2,
47
+ )
48
+ hf_llm = HuggingFacePipeline(pipeline=pipe)
49
 
50
 
51
+ # class ChatRequest(BaseModel):
52
+ # messages: list
53
+ # temperature: float = 1.0
54
+ # max_tokens: int = 50
55
+ # stream: bool = False
56
 
57
+ # class ChatResponse(BaseModel):
58
+ # response: str
59
 
60
+ # @app.post("/v1/chat/completions", response_model=ChatResponse)
61
+ # async def chat_completions(request: ChatRequest):
62
+ # try:
63
+ # # Prepare input prompt
64
+ # input_prompt = ""
65
+ # for message in request.messages:
66
+ # role = message.get('role', 'user')
67
+ # content = message.get('content', '')
68
+ # input_prompt += f"{role}: {content}\n"
69
 
70
+ # # Tokenize and generate response
71
+ # input_ids = tokenizer.encode(input_prompt, return_tensors='pt')
72
+ # output = model.generate(input_ids, max_length=1024, temperature=request.temperature, max_tokens=request.max_tokens)
73
 
74
+ # # Decode and send response
75
+ # response = tokenizer.decode(output[0], skip_special_tokens=True)
76
+ # return {"response": response}
77
 
78
+ # except Exception as e:
79
+ # raise HTTPException(status_code=500, detail=str(e))
80
 
81
 
82
  @app.get("/")
 
89
  # return {"message": result}
90
 
91
 
92
+ async def askLLM(prompt):
93
+ output = pipe(prompt,do_sample=False)
94
+ return output
95
 
96
+ @app.post("/ask_llm")
97
+ async def ask_llm_endpoint(prompt: str):
98
+ # result = await askLLM(prompt)
99
+ result = pipe(prompt,do_sample=False)
100
+ return {"result": result}
101
 
102
 
103
  # @app.post("/ask_HFAPI")