File size: 5,449 Bytes
118856c
 
 
 
 
 
 
 
 
 
 
 
82f4497
 
 
118856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01d810b
118856c
 
 
01d810b
118856c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
919d2f7
 
118856c
82f4497
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import secrets
import threading
from typing import List, Tuple

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI()

# Enable CORS
# NOTE(review): every origin, method and header is allowed while
# allow_credentials=True. FastAPI's CORS documentation advises listing origins
# explicitly when credentials are allowed -- confirm this permissive setup is
# intended before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Checkpoint used by the /chat endpoint, loaded once at import time together
# with its tokenizer.
model_name = "gpt2"  # Replace with a smaller model if needed
token = os.getenv("hugging_face_token")  # Get the token from environment variable

# Forward the access token only when a non-empty one is configured; otherwise
# no ``token`` argument is passed at all.
_hub_auth = {"token": token} if token else {}
tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_auth)
model = AutoModelForCausalLM.from_pretrained(model_name, **_hub_auth)

# In-memory user store
# Maps username -> {"password", "progress", "token"} as written by sign_up().
# Being a plain dict, its contents exist only for the lifetime of the process.
users = {}
# NOTE(review): not referenced anywhere else in this file -- confirm whether
# it is still needed.
current_user = None

# OAuth2 setup
# Bearer-token dependency used by get_current_user(); tokenUrl points at the
# "/token" login route defined below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

class User(BaseModel):
    """Request body of the /signup endpoint."""
    username: str  # used as the key in the in-memory ``users`` store
    password: str  # stored by sign_up() exactly as received

class Progress(BaseModel):
    """Request body of /track_progress: the completion state of one day."""
    day: int  # 1-based day number; the endpoint accepts 1..75
    completed: bool  # value stored for that day

class ChatMessage(BaseModel):
    """Request body of /chat: the new message plus generation settings."""
    message: str  # latest user message, placed after the history in the prompt
    history: List[Tuple[str, str]]  # earlier (user, assistant) turns, in prompt order
    system_message: str  # becomes the first ("system") line of the prompt
    max_tokens: int  # generation length limit forwarded to model.generate()
    temperature: float  # sampling temperature forwarded to model.generate()
    top_p: float  # top_p sampling parameter forwarded to model.generate()

def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the bearer token of a request to its stored user record.

    The token is looked up directly as a key of the in-memory ``users`` store.

    Args:
        token: Bearer token extracted by ``oauth2_scheme``.

    Returns:
        The user's record (a dict) from ``users``.

    Raises:
        HTTPException: 401 when the token matches no stored user.
    """
    try:
        return users[token]
    except KeyError:
        raise HTTPException(
            status_code=401, detail="Invalid authentication credentials"
        ) from None

@app.post("/signup")
def sign_up(user: User):
    """Register a new account in the in-memory user store.

    Args:
        user: Requested username and password.

    Returns:
        A confirmation message.

    Raises:
        HTTPException: 400 when the username is already taken.
    """
    name = user.username
    if name in users:
        raise HTTPException(status_code=400, detail="Username already exists")

    # NOTE(review): the password is stored as received and the account's token
    # is simply its username -- acceptable for a demo only.
    users[name] = {
        "password": user.password,
        "progress": [False for _ in range(75)],  # one flag per tracked day
        "token": name,
    }
    return {"message": "User signed up successfully"}

@app.post("/token")
def log_in(form_data: OAuth2PasswordRequestForm = Depends()):
    """Check a username/password pair and hand out the bearer token.

    Args:
        form_data: OAuth2 password-flow form carrying ``username`` and
            ``password``.

    Returns:
        ``{"access_token": <username>, "token_type": "bearer"}``; the token is
        the username itself, which is what get_current_user() looks up.

    Raises:
        HTTPException: 400 when the user is unknown or the password differs.
    """
    account = users.get(form_data.username)
    # Compare in constant time so the response time does not reveal how much
    # of the password matched. The values are encoded to bytes because
    # compare_digest() rejects non-ASCII str arguments.
    password_ok = account is not None and secrets.compare_digest(
        account["password"].encode("utf-8"),
        form_data.password.encode("utf-8"),
    )
    if not password_ok:
        raise HTTPException(status_code=400, detail="Invalid username or password")
    return {"access_token": form_data.username, "token_type": "bearer"}

@app.post("/track_progress")
def track_progress(progress: Progress, user: dict = Depends(get_current_user)):
    """Record whether the authenticated user completed a given day.

    Args:
        progress: Day number (1..75) and its completion flag.
        user: Record of the authenticated user, injected by the dependency.

    Returns:
        A confirmation message naming the updated day.

    Raises:
        HTTPException: 400 when the day lies outside 1..75.
    """
    day = progress.day
    if not 1 <= day <= 75:
        raise HTTPException(status_code=400, detail="Invalid day")

    # Days are 1-based in the API, the progress list is 0-based.
    user["progress"][day - 1] = progress.completed
    return {"message": f"Progress for day {day} updated"}

@app.post("/chat")
def chat(chat_message: ChatMessage, user: dict = Depends(get_current_user)):
    """Generate a model reply to the latest message of an authenticated user.

    The prompt is a plain-text transcript with one ``role: content`` line per
    message: the system message first, then every non-empty history entry,
    then the new user message.

    Args:
        chat_message: New message, prior history and sampling settings.
        user: Record of the authenticated user, injected by the dependency.

    Returns:
        ``{"response": text}`` where ``text`` is the newly generated
        continuation (the prompt is not echoed back), followed by a progress
        question while the computed day number is at most 75.
    """
    turns = [("system", chat_message.system_message)]
    for user_text, assistant_text in chat_message.history:
        # Empty strings in a history pair are skipped rather than emitted as
        # blank turns.
        if user_text:
            turns.append(("user", user_text))
        if assistant_text:
            turns.append(("assistant", assistant_text))
    turns.append(("user", chat_message.message))

    # Prepare the input for the model
    input_text = "".join(f"{role}: {content}\n" for role, content in turns)

    inputs = tokenizer(input_text, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # max_new_tokens limits only the generated part; max_length would also
        # count the prompt and leave no room once the transcript grows.
        max_new_tokens=chat_message.max_tokens,
        temperature=chat_message.temperature,
        top_p=chat_message.top_p,
        do_sample=True,
        pad_token_id=pad_token_id,
    )
    # The output sequence starts with the prompt tokens; decode only what the
    # model appended.
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Example of how to use user progress in the response: the day asked about
    # is the number of completed days plus one.
    day = sum(user["progress"]) + 1
    if day <= 75:
        response += f"\nHow did you do on day {day}?"

    return {"response": response}

# Gradio interface for sign-up and login
def sign_up_interface(username, password):
    """Gradio callback: register an account through the local /signup endpoint.

    Args:
        username: Value of the username textbox.
        password: Value of the password textbox.

    Returns:
        The decoded JSON body of the API response, or a ``{"detail": ...}``
        dict when the API could not be reached.
    """
    try:
        # requests needs an absolute URL; the API is served by uvicorn on
        # port 8000 of this machine.
        response = requests.post(
            "http://127.0.0.1:8000/signup",
            json={"username": username, "password": password},
            timeout=10,
        )
    except requests.RequestException as exc:
        return {"detail": f"Could not reach the API: {exc}"}
    return response.json()

def log_in_interface(username, password):
    """Gradio callback: log in through the local /token endpoint.

    Args:
        username: Value of the username textbox.
        password: Value of the password textbox.

    Returns:
        The access token string on HTTP 200; otherwise the decoded JSON error
        body, or a ``{"detail": ...}`` dict when the API could not be reached.
    """
    try:
        # requests needs an absolute URL; the API is served by uvicorn on
        # port 8000 of this machine. The credentials go out form-encoded,
        # which is what OAuth2PasswordRequestForm reads.
        response = requests.post(
            "http://127.0.0.1:8000/token",
            data={"username": username, "password": password},
            timeout=10,
        )
    except requests.RequestException as exc:
        return {"detail": f"Could not reach the API: {exc}"}

    payload = response.json()
    if response.status_code == 200:
        return payload["access_token"]
    return payload

def main():
    """Build the Gradio front end and launch it.

    The UI has three tabs: a sign-up form, a log-in form and a placeholder
    for the main application. Each form sends its two textboxes to the
    matching ``*_interface`` callback and shows the result in an output box.
    """
    with gr.Blocks() as demo:
        with gr.Tab("Sign Up"):
            signup_name = gr.Textbox(label="Username")
            signup_secret = gr.Textbox(label="Password", type="password")
            signup_submit = gr.Button("Sign Up")
            signup_result = gr.Textbox(label="Sign Up Output")
            signup_submit.click(
                sign_up_interface,
                inputs=[signup_name, signup_secret],
                outputs=signup_result,
            )

        with gr.Tab("Log In"):
            login_name = gr.Textbox(label="Username")
            login_secret = gr.Textbox(label="Password", type="password")
            login_submit = gr.Button("Log In")
            login_result = gr.Textbox(label="Log In Output")
            login_submit.click(
                log_in_interface,
                inputs=[login_name, login_secret],
                outputs=login_result,
            )

        # Placeholder for the main application interface
        with gr.Tab("Main Application"):
            gr.Markdown("This is the main application interface. Add your main application components here.")

    demo.launch()

if __name__ == "__main__":
    # Run the Gradio UI in a background thread and the API server in the main
    # thread. The UI thread is a daemon so that it cannot keep the interpreter
    # alive after the uvicorn server below has shut down.
    threading.Thread(target=main, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=8000)