# abhlash
# updated the code
# 01d810b
import hashlib
import hmac
import os
import secrets
import threading
from typing import List, Tuple

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
app = FastAPI()

# Origins allowed to call the API from a browser, given as a comma-separated
# list in the CORS_ALLOW_ORIGINS environment variable. Defaults to "*" (any
# origin).
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # With a "*" origin list plus credentials, Starlette answers credentialed
    # requests by echoing the caller's Origin, so any website could make
    # credentialed calls. This API authenticates with a bearer header rather
    # than cookies, so credentials are only enabled for an explicit allow-list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the model and tokenizer from Hugging Face
model_name = "gpt2"  # Replace with a smaller model if needed
token = os.getenv("hugging_face_token")  # Get the token from environment variable
# Forward the token only when one is configured (a missing or empty value
# means anonymous hub access); both loaders receive the same arguments.
_hub_kwargs = {"token": token} if token else {}
tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_kwargs)
model = AutoModelForCausalLM.from_pretrained(model_name, **_hub_kwargs)
# In-memory user store, keyed by username; its contents live only as long as
# the process does.
users: dict = dict()
# NOTE(review): not referenced anywhere in this module; the authenticated user
# is resolved per request through the get_current_user dependency instead.
current_user = None
# OAuth2 setup: clients obtain bearer tokens from the "token" endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class User(BaseModel):
    """Request body for the /signup endpoint: credentials for a new account."""

    username: str
    password: str
class Progress(BaseModel):
    """Request body for the /track_progress endpoint: one day's completion state."""

    # 1-based day number; the /track_progress handler rejects values outside 1..75.
    day: int
    # Completion flag stored for that day.
    completed: bool
class ChatMessage(BaseModel):
    """Request body for the /chat endpoint."""

    # The new user message to answer.
    message: str
    # Earlier turns as (user_text, assistant_text) pairs; empty strings are skipped.
    history: List[Tuple[str, str]]
    # Text placed first in the prompt under the "system" role.
    system_message: str
    # Token limit for generation; see the /chat endpoint for how it is applied.
    max_tokens: int
    # Sampling parameters used by the /chat endpoint's model.generate call.
    temperature: float
    top_p: float
# Bearer tokens issued by /token, mapped to the username they authenticate.
# Tokens are random, so knowing a username alone does not grant access.
_sessions = {}

# PBKDF2-HMAC-SHA256 work factor (OWASP password-storage guidance).
_PBKDF2_ITERATIONS = 600_000


def _hash_password(password: str, salt: bytes) -> bytes:
    """Return the PBKDF2-HMAC-SHA256 digest of *password* under *salt*."""
    return hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt, _PBKDF2_ITERATIONS
    )


def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the request's bearer token to the stored user record.

    Raises:
        HTTPException: 401 if the token was not issued by /token or its user
            no longer exists.
    """
    username = _sessions.get(token)
    if username is None or username not in users:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return users[username]


@app.post("/signup")
def sign_up(user: User):
    """Register a new user with a salted password hash and 75 empty progress days.

    Raises:
        HTTPException: 400 if the username is already taken.
    """
    if user.username in users:
        raise HTTPException(status_code=400, detail="Username already exists")
    salt = secrets.token_bytes(16)
    users[user.username] = {
        # Only a salted hash is kept; the plaintext password is never stored.
        "password_hash": _hash_password(user.password, salt),
        "salt": salt,
        "progress": [False] * 75,
    }
    return {"message": "User signed up successfully"}


@app.post("/token")
def log_in(form_data: OAuth2PasswordRequestForm = Depends()):
    """Check credentials and issue a fresh random bearer token.

    Raises:
        HTTPException: 400 if the username is unknown or the password is wrong.
    """
    record = users.get(form_data.username)
    if record is None or not hmac.compare_digest(
        # Constant-time comparison avoids leaking how much of the hash matched.
        record["password_hash"],
        _hash_password(form_data.password, record["salt"]),
    ):
        raise HTTPException(status_code=400, detail="Invalid username or password")
    access_token = secrets.token_urlsafe(32)
    _sessions[access_token] = form_data.username
    return {"access_token": access_token, "token_type": "bearer"}
@app.post("/track_progress")
def track_progress(progress: Progress, user: dict = Depends(get_current_user)):
    """Store the completion flag for one of the user's 75 progress slots.

    Raises:
        HTTPException: 400 when ``progress.day`` lies outside 1..75.
    """
    day_number = progress.day
    if not 1 <= day_number <= 75:
        raise HTTPException(status_code=400, detail="Invalid day")
    # Days are 1-based for clients; the stored list is 0-based.
    user["progress"][day_number - 1] = progress.completed
    return {"message": f"Progress for day {day_number} updated"}
def _build_prompt(chat_message: ChatMessage) -> str:
    """Flatten the system message, history and new message into one prompt.

    Each turn becomes a ``"role: content"`` line and empty history entries are
    skipped. The prompt ends with an open ``assistant:`` turn to cue the model
    to answer in the assistant's role.
    """
    turns = [("system", chat_message.system_message)]
    for user_text, assistant_text in chat_message.history:
        if user_text:
            turns.append(("user", user_text))
        if assistant_text:
            turns.append(("assistant", assistant_text))
    turns.append(("user", chat_message.message))
    return "".join(f"{role}: {content}\n" for role, content in turns) + "assistant:"


def _next_unfinished_day(progress: List[bool]) -> int:
    """Return the 1-based number of the first day not yet completed.

    Returns ``len(progress) + 1`` when every day is complete.
    """
    for index, completed in enumerate(progress):
        if not completed:
            return index + 1
    return len(progress) + 1


@app.post("/chat")
def chat(chat_message: ChatMessage, user: dict = Depends(get_current_user)):
    """Generate a reply to ``chat_message`` for the authenticated ``user``.

    ``max_tokens`` caps the number of newly generated tokens. A
    ``temperature`` of 0 or less selects greedy decoding. The reply holds only
    the generated text (not the echoed prompt). When some day is still
    unfinished, a question about the first such day is appended.

    Raises:
        HTTPException: 400 if ``max_tokens`` is smaller than 1.
    """
    if chat_message.max_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be at least 1")

    # Prepare the input for the model
    inputs = tokenizer(_build_prompt(chat_message), return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    # transformers rejects sampling with a non-positive temperature, so treat
    # such a request as deterministic (greedy) decoding instead.
    if chat_message.temperature > 0:
        sampling_kwargs = {
            "do_sample": True,
            "temperature": chat_message.temperature,
            "top_p": chat_message.top_p,
        }
    else:
        sampling_kwargs = {"do_sample": False}

    # NOTE(review): models such as GPT-2 have a fixed context window
    # (model.config.n_positions); a very long history plus max_tokens can
    # exceed it. Consider truncating the prompt from the left.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # max_new_tokens counts only generated tokens; max_length would also
        # count the prompt and can leave no room for a reply.
        max_new_tokens=chat_message.max_tokens,
        # The GPT-2 tokenizer defines no pad token; reuse EOS for padding.
        pad_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
    )
    # outputs[0] begins with the prompt tokens; decode only the continuation.
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Ask about the first unfinished day; counting completed days would point
    # at an already-completed day whenever days are finished out of order.
    progress = user["progress"]
    day = _next_unfinished_day(progress)
    if day <= len(progress):
        response += f"\nHow did you do on day {day}?"
    return {"response": response}
# Gradio interface for sign-up and login
# Base URL of the FastAPI backend. The __main__ block serves the API on port
# 8000, so the default targets that local instance; override with API_BASE_URL.
API_BASE_URL = os.getenv("API_BASE_URL", "http://127.0.0.1:8000").rstrip("/")
# Seconds to wait for the backend so a stalled request cannot hang the UI.
REQUEST_TIMEOUT = 30


def _parse_json(response):
    """Return the response body as JSON, or a dict describing a non-JSON reply."""
    try:
        return response.json()
    except ValueError:
        return {"status_code": response.status_code, "body": response.text}


def sign_up_interface(username, password):
    """Register *username*/*password* with the backend and return its reply.

    Returns the backend's JSON body, or an ``{"error": ...}`` dict if the
    backend cannot be reached.
    """
    try:
        response = requests.post(
            f"{API_BASE_URL}/signup",
            json={"username": username, "password": password},
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as exc:
        return {"error": f"Could not reach the API: {exc}"}
    return _parse_json(response)


def log_in_interface(username, password):
    """Exchange credentials for a bearer token via the backend's /token endpoint.

    Returns the access token string on HTTP 200. Otherwise returns the error
    body, or an ``{"error": ...}`` dict if the backend cannot be reached.
    """
    try:
        # /token expects form-encoded credentials (OAuth2 password flow).
        response = requests.post(
            f"{API_BASE_URL}/token",
            data={"username": username, "password": password},
            timeout=REQUEST_TIMEOUT,
        )
    except requests.RequestException as exc:
        return {"error": f"Could not reach the API: {exc}"}
    if response.status_code == 200:
        return response.json()["access_token"]
    return _parse_json(response)
def main():
    """Build the Gradio front end and launch it.

    The UI has three tabs: sign-up and log-in forms wired to
    sign_up_interface / log_in_interface, and a placeholder tab.
    """
    with gr.Blocks() as demo:
        with gr.Tab("Sign Up"):
            signup_name_box = gr.Textbox(label="Username")
            signup_secret_box = gr.Textbox(label="Password", type="password")
            signup_trigger = gr.Button("Sign Up")
            signup_result = gr.Textbox(label="Sign Up Output")
            signup_trigger.click(
                sign_up_interface,
                inputs=[signup_name_box, signup_secret_box],
                outputs=signup_result,
            )

        with gr.Tab("Log In"):
            login_name_box = gr.Textbox(label="Username")
            login_secret_box = gr.Textbox(label="Password", type="password")
            login_trigger = gr.Button("Log In")
            login_result = gr.Textbox(label="Log In Output")
            login_trigger.click(
                log_in_interface,
                inputs=[login_name_box, login_secret_box],
                outputs=login_result,
            )

        # Placeholder for the main application interface
        with gr.Tab("Main Application"):
            gr.Markdown("This is the main application interface. Add your main application components here.")

    demo.launch()
if __name__ == "__main__":
    # Serve the Gradio UI from a background thread and the API from the main
    # thread. daemon=True lets the process exit once uvicorn stops (e.g. on
    # Ctrl+C); a non-daemon UI thread would keep the interpreter alive
    # indefinitely after the API has shut down.
    threading.Thread(target=main, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=8000)